Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-07 17:50:35 +00:00)

Compare commits: `jacob/curr` ... `langchain`
70 commits (SHA1 only; the author and date columns were empty in this mirror):

9d0c1d2dc9 · a7296bddc2 · c9473367b1 · f77659463a · ccdaf14eff · cacdf96f9c · 36ee083753 · e8a21146d3 · a0958c0607 · 620b118c70 ·
888fbc07b5 · ab2d7821a7 · 6fc7610b1c · 0da5078cad · d0728b0ba0 · 9224027e45 · 5c3e2612da · 65321bf975 · 2b7d1cdd2f · a653b209ba ·
f071581aea · f0a7581b50 · 474b88326f · bdc03997c9 · 3f1cf00d97 · 6b47c7361e · 7677ceea60 · aee55eda39 · d09dda5a08 · 12950cc602 ·
e8ee781a42 · 02e71cebed · 259d4d2029 · 3aed74a6fc · 13b0d7ec8f · 71cd6e6feb · 99054e19eb · 7a1321e2f9 · cb5031f22f · f1618ec540 ·
8d82a0d483 · 0a1e475a30 · 6166ea67a8 · d77d9bfc00 · aa3e3cfa40 · 14ba1d4b45 · 18da9f5e59 · d3a2b9fae0 · 7014d07cab · 35784d1c33 ·
8858846607 · ffe6ca986e · 7790d67f94 · 1132fb801b · 1d37aa8403 · cb95198398 · d002fa902f · 8d100c58de · 5fd1e67808 · eeb996034b ·
03fba07d15 · c481a2715d · 8ee8ca7c83 · 4121d4151f · bd18faa2a0 · f1f1f75782 · 4ba14adec6 · 457677c1b7 · 8327925ab7 · 122e80e04d
1  .github/scripts/get_min_versions.py (vendored)
```diff
@@ -9,6 +9,7 @@ MIN_VERSION_LIBS = [
     "langchain-community",
     "langchain",
     "langchain-text-splitters",
+    "SQLAlchemy",
 ]
```
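For context, this script appears to collect the minimum supported version of each library in `MIN_VERSION_LIBS` so CI can test against the lower bounds. A rough sketch of extracting that minimum from a version specifier; this helper is illustrative only, not the script's actual code:

```python
import re


# Illustrative only: pull the lower bound out of a specifier
# such as ">=0.2.8,<0.3" or a Poetry-style "^0.2.8".
def min_version_from_spec(spec: str) -> str:
    match = re.search(r"[>^~]?=?\s*([0-9]+(?:\.[0-9]+)*)", spec)
    if match is None:
        raise ValueError(f"no version found in {spec!r}")
    return match.group(1)


assert min_version_from_spec(">=0.2.8,<0.3") == "0.2.8"
assert min_version_from_spec("^0.2.8") == "0.2.8"
```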
```diff
@@ -21,14 +21,6 @@ jobs:
       run:
         working-directory: ${{ inputs.working-directory }}
     runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version:
-          - "3.8"
-          - "3.9"
-          - "3.10"
-          - "3.11"
-          - "3.12"
     name: "poetry run pytest -m compile tests/integration_tests #${{ inputs.python-version }}"
     steps:
       - uses: actions/checkout@v4
```
4  .github/workflows/_test_doc_imports.yml (vendored)
```diff
@@ -14,10 +14,6 @@ env:
 jobs:
   build:
     runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version:
-          - "3.12"
     name: "check doc imports #${{ inputs.python-version }}"
     steps:
       - uses: actions/checkout@v4
```
```diff
@@ -78,7 +78,7 @@ def _load_module_members(module_path: str, namespace: str) -> ModuleMembers:
                 continue

             if inspect.isclass(type_):
-                # The clasification of the class is used to select a template
+                # The type of the class is used to select a template
                 # for the object when rendering the documentation.
                 # See `templates` directory for defined templates.
                 # This is a hacky solution to distinguish between different
```
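The comment being fixed here describes how the API-reference generator picks a rendering template per kind of object. A loose sketch of that selection pattern; the template names and categories below are illustrative, not the generator's real ones:

```python
import inspect
from enum import Enum
from typing import Any


# Illustrative only: map an object to a documentation template name.
def select_template(obj: Any) -> str:
    if not inspect.isclass(obj):
        return "function.rst"
    if issubclass(obj, Enum):
        return "enum.rst"
    if issubclass(obj, Exception):
        return "error.rst"
    return "class.rst"
```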
```diff
@@ -821,7 +821,7 @@ We recommend this method as a starting point when working with structured output
 - If multiple underlying techniques are supported, you can supply a `method` parameter to
 [toggle which one is used](/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs).

-You may want or need to use other techiniques if:
+You may want or need to use other techniques if:

 - The chat model you are using does not support tool calling.
 - You are working with very complex schemas and the model is having trouble generating outputs that conform.
```
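To make the `method` toggle concrete, a minimal sketch assuming an OpenAI chat model; note that with `method="json_mode"` the prompt itself must also instruct the model to emit JSON with the expected keys:

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI


class Person(BaseModel):
    """Information about a person."""

    name: str = Field(description="The person's name")
    age: int = Field(description="The person's age")


llm = ChatOpenAI(model="gpt-4o", temperature=0)

# Default technique (tool/function calling on OpenAI models).
structured_llm = llm.with_structured_output(Person)

# Explicitly toggle the underlying technique instead.
json_llm = llm.with_structured_output(Person, method="json_mode")
```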
```diff
@@ -33,6 +33,8 @@ Some examples include:

 - [Build a Simple LLM Application with LCEL](/docs/tutorials/llm_chain/)
 - [Build a Retrieval Augmented Generation (RAG) App](/docs/tutorials/rag/)

+A good structural rule of thumb is to follow the structure of this [example from Numpy](https://numpy.org/numpy-tutorials/content/tutorial-svd.html).
+
 Here are some high-level tips on writing a good tutorial:
```
```diff
@@ -15,6 +15,12 @@ (markdown cell)

 Make sure you have the integration packages installed for any model providers you want to support. E.g. you should have `langchain-openai` installed to init an OpenAI model.

 :::
+
+:::info Requires ``langchain >= 0.2.8``
+
+This functionality was added in ``langchain-core == 0.2.8``. Please make sure your package is up to date.
+
+:::
```
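For reference, the basic call this notebook documents looks like the following minimal sketch; the model names are illustrative, and each provider needs its integration package and API key installed:

```python
from langchain.chat_models import init_chat_model

# One entry point for many providers; the provider can be passed
# explicitly or often inferred from the model name.
gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
claude = init_chat_model(
    "claude-3-5-sonnet-20240620", model_provider="anthropic", temperature=0
)

print(gpt_4o.invoke("what's your name").content)
```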
```diff
@@ -25,7 +31,7 @@ (code cell)
-%pip install -qU langchain langchain-openai langchain-anthropic langchain-google-vertexai
+%pip install -qU langchain>=0.2.8 langchain-openai langchain-anthropic langchain-google-vertexai
```
```diff
@@ -76,32 +82,6 @@ (removes the "Simple config example" section)
 print("Gemini 1.5: " + gemini_15.invoke("what's your name").content + "\n")
-## Simple config example
-
-user_config = {
-    "model": "...user-specified...",
-    "model_provider": "...user-specified...",
-    "temperature": 0,
-    "max_tokens": 1000,
-}
-
-llm = init_chat_model(**user_config)
-llm.invoke("what's your name")
```
@@ -125,12 +105,215 @@ replaces an empty code cell (id da07b5c0) with the following new cells:

## Creating a configurable model

You can also create a runtime-configurable model by specifying `configurable_fields`. If you don't specify a `model` value, then "model" and "model_provider" will be configurable by default.

```python
configurable_model = init_chat_model(temperature=0)

configurable_model.invoke(
    "what's your name", config={"configurable": {"model": "gpt-4o"}}
)
```
```
AIMessage(content="I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d576307f90', 'finish_reason': 'stop', 'logprobs': None}, id='run-5428ab5c-b5c0-46de-9946-5d4ca40dbdc8-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})
```

```python
configurable_model.invoke(
    "what's your name", config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
)
```
```
AIMessage(content="My name is Claude. It's nice to meet you!", response_metadata={'id': 'msg_012XvotUJ3kGLXJUWKBVxJUi', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-1ad1eefe-f1c6-4244-8bc6-90e2cb7ee554-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})
```

### Configurable model with default values

We can create a configurable model with default model values, specify which parameters are configurable, and add prefixes to configurable params:

```python
first_llm = init_chat_model(
    model="gpt-4o",
    temperature=0,
    configurable_fields=("model", "model_provider", "temperature", "max_tokens"),
    config_prefix="first",  # useful when you have a chain with multiple models
)

first_llm.invoke("what's your name")
```
```
AIMessage(content="I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_ce0793330f', 'finish_reason': 'stop', 'logprobs': None}, id='run-3923e328-7715-4cd6-b215-98e4b6bf7c9d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})
```

```python
first_llm.invoke(
    "what's your name",
    config={
        "configurable": {
            "first_model": "claude-3-5-sonnet-20240620",
            "first_temperature": 0.5,
            "first_max_tokens": 100,
        }
    },
)
```
```
AIMessage(content="My name is Claude. It's nice to meet you!", response_metadata={'id': 'msg_01RyYR64DoMPNCfHeNnroMXm', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-22446159-3723-43e6-88df-b84797e7751d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})
```

### Using a configurable model declaratively

We can call declarative operations like `bind_tools`, `with_structured_output`, `with_config`, etc. on a configurable model and chain a configurable model in the same way that we would a regularly instantiated chat model object.

```python
from langchain_core.pydantic_v1 import BaseModel, Field


class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")


class GetPopulation(BaseModel):
    """Get the current population in a given location"""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")


llm = init_chat_model(temperature=0)
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])

llm_with_tools.invoke(
    "what's bigger in 2024 LA or NYC", config={"configurable": {"model": "gpt-4o"}}
).tool_calls
```
```
[{'name': 'GetPopulation',
  'args': {'location': 'Los Angeles, CA'},
  'id': 'call_sYT3PFMufHGWJD32Hi2CTNUP'},
 {'name': 'GetPopulation',
  'args': {'location': 'New York, NY'},
  'id': 'call_j1qjhxRnD3ffQmRyqjlI1Lnk'}]
```

```python
llm_with_tools.invoke(
    "what's bigger in 2024 LA or NYC",
    config={"configurable": {"model": "claude-3-5-sonnet-20240620"}},
).tool_calls
```
```
[{'name': 'GetPopulation',
  'args': {'location': 'Los Angeles, CA'},
  'id': 'toolu_01CxEHxKtVbLBrvzFS7GQ5xR'},
 {'name': 'GetPopulation',
  'args': {'location': 'New York City, NY'},
  'id': 'toolu_013A79qt5toWSsKunFBDZd5S'}]
```
```diff
@@ -149,7 +332,7 @@ (notebook metadata)
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.1"
+"version": "3.11.9"
 }
},
"nbformat": 4,
```
```diff
@@ -48,20 +48,10 @@ (code cell, id 40ed76a2: %pip install --upgrade --quiet langchain langchain-openai)
-"execution_count": 1,
+"execution_count": null,
-(drops the recorded stdout, a stale pip warning: "WARNING: You are using pip version 22.0.4; however, version 24.0 is available. You should consider upgrading via the '/Users/jacoblee/.pyenv/versions/3.10.5/bin/python -m pip install --upgrade pip' command. Note: you may need to restart the kernel to use updated packages.")
+"outputs": [],
```
```diff
@@ -180,7 +180,7 @@ (markdown cell, id 32b1a992-8997-4c98-8eb2-c9fe9431b799)
-Alternatively, we can add typing information via [Runnable.with_types](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html#langchain_core.runnables.base.Runnable.with_types):
+Alternatively, the schema can be fully specified by directly passing the desired [args_schema](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool.args_schema) for the tool:
```
```diff
@@ -190,10 +190,18 @@ (code cell)
-as_tool = runnable.with_types(input_type=Args).as_tool(
-    name="My tool",
-    description="Explanation of when to use tool.",
-)
+from langchain_core.pydantic_v1 import BaseModel, Field
+
+
+class GSchema(BaseModel):
+    """Apply a function to an integer and list of integers."""
+
+    a: int = Field(..., description="Integer")
+    b: List[int] = Field(..., description="List of ints")
+
+
+runnable = RunnableLambda(g)
+as_tool = runnable.as_tool(GSchema)
```
```diff
@@ -131,7 +131,7 @@ (markdown cell)
 ## Base Chat Model

-Let's implement a chat model that echoes back the first `n` characetrs of the last message in the prompt!
+Let's implement a chat model that echoes back the first `n` characters of the last message in the prompt!

 To do so, we will inherit from `BaseChatModel` and we'll need to implement the following:
```
```diff
@@ -16,13 +16,15 @@ (markdown cell)
 | args_schema | Pydantic BaseModel | Optional but recommended, can be used to provide more information (e.g., few-shot examples) or validation for expected parameters |
 | return_direct | boolean | Only relevant for agents. When True, after invoking the given tool, the agent will stop and return the result directly to the user. |

-LangChain provides 3 ways to create tools:
+LangChain supports the creation of tools from:

-1. Using [@tool decorator](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html#langchain_core.tools.tool) -- the simplest way to define a custom tool.
-2. Using [StructuredTool.from_function](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.StructuredTool.html#langchain_core.tools.StructuredTool.from_function) class method -- this is similar to the `@tool` decorator, but allows more configuration and specification of both sync and async implementations.
+1. Functions;
+2. LangChain [Runnables](/docs/concepts#runnable-interface);
 3. By sub-classing from [BaseTool](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html) -- This is the most flexible method, it provides the largest degree of control, at the expense of more effort and code.

-The `@tool` or the `StructuredTool.from_function` class method should be sufficient for most use cases.
+Creating tools from functions may be sufficient for most use cases, and can be done via a simple [@tool decorator](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html#langchain_core.tools.tool). If more configuration is needed -- e.g., specification of both sync and async implementations -- one can also use the [StructuredTool.from_function](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.StructuredTool.html#langchain_core.tools.StructuredTool.from_function) class method.

 In this guide we provide an overview of these methods.

 :::{.callout-tip}

@@ -35,7 +37,9 @@ (markdown cell, id c7326b23)
-## @tool decorator
+## Creating tools from functions
+
+### @tool decorator

 This `@tool` decorator is the simplest way to define a custom tool. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description - so a docstring MUST be provided.
```
```diff
@@ -51,7 +55,7 @@ (recorded output of the multiply cell)
 multiply
-multiply(a: int, b: int) -> int - Multiply two numbers.
+Multiply two numbers.
 {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}}
```
@@ -96,6 +100,57 @@ adds the following cells after `    return a * b`:

Note that `@tool` supports parsing of annotations, nested schemas, and other features:

```python
from typing import Annotated, List


@tool
def multiply_by_max(
    a: Annotated[str, "scale factor"],
    b: Annotated[List[int], "list of ints over which to take maximum"],
) -> int:
    """Multiply a by the maximum of b."""
    return a * max(b)


multiply_by_max.args_schema.schema()
```
```
{'title': 'multiply_by_maxSchema',
 'description': 'Multiply a by the maximum of b.',
 'type': 'object',
 'properties': {'a': {'title': 'A',
   'description': 'scale factor',
   'type': 'string'},
  'b': {'title': 'B',
   'description': 'list of ints over which to take maximum',
   'type': 'array',
   'items': {'type': 'integer'}}},
 'required': ['a', 'b']}
```
```diff
@@ -106,7 +161,7 @@ (code cell, id 9216d03a-f6ea-4216-b7e1-0661823a4c0b)
-"execution_count": 3,
+"execution_count": 4,
@@ -115,7 +170,7 @@ (that cell's recorded output)
 multiplication-tool
-multiplication-tool(a: int, b: int) -> int - Multiply two numbers.
+Multiply two numbers.
 {'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}
 True
```
@@ -143,19 +198,84 @@ adds the following cells after `print(multiply.return_direct)`:

#### Docstring parsing

`@tool` can optionally parse [Google Style docstrings](https://google.github.io/styleguide/pyguide.html#383-functions-and-methods) and associate the docstring components (such as arg descriptions) to the relevant parts of the tool schema. To toggle this behavior, specify `parse_docstring`:

```python
@tool(parse_docstring=True)
def foo(bar: str, baz: int) -> str:
    """The foo.

    Args:
        bar: The bar.
        baz: The baz.
    """
    return bar


foo.args_schema.schema()
```
```
{'title': 'fooSchema',
 'description': 'The foo.',
 'type': 'object',
 'properties': {'bar': {'title': 'Bar',
   'description': 'The bar.',
   'type': 'string'},
  'baz': {'title': 'Baz', 'description': 'The baz.', 'type': 'integer'}},
 'required': ['bar', 'baz']}
```

:::{.callout-caution}
By default, `@tool(parse_docstring=True)` will raise `ValueError` if the docstring does not parse correctly. See [API Reference](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html) for detail and examples.
:::
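As a sketch of the failure mode that caution describes; the exact conditions and message come from the API reference, and this hypothetical docstring is simply not valid Google style:

```python
from langchain_core.tools import tool

try:

    @tool(parse_docstring=True)
    def bad_tool(bar: str) -> str:
        """Args: malformed, no summary line in the expected format"""
        return bar

except ValueError as e:
    # Expected per the caution above when the docstring can't be parsed.
    print(f"Failed to create tool: {e}")
```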
```diff
 (the same hunk continues with the StructuredTool section, markdown cell id b63fcc3b)
-## StructuredTool
+### StructuredTool

 The `StructuredTool.from_function` class method provides a bit more configurability than the `@tool` decorator, without requiring much additional code.

 (code cell, id 564fbe6f-11df-402d-b135-ef6ff25e1e63)
-"execution_count": 4,
+"execution_count": 6,
```
```diff
@@ -198,7 +318,7 @@ (code cell, id 6bc055d4-1fbe-4db5-8881-9c382eba6b1b)
-"execution_count": 5,
+"execution_count": 7,
@@ -208,7 +328,7 @@ (that cell's recorded output)
 6
 Calculator
-Calculator(a: int, b: int) -> int - multiply numbers
+multiply numbers
 {'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}
```
@@ -239,6 +359,63 @@ adds the following cells after `print(calculator.args)`:

## Creating tools from Runnables

LangChain [Runnables](/docs/concepts#runnable-interface) that accept string or `dict` input can be converted to tools using the [as_tool](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html#langchain_core.runnables.base.Runnable.as_tool) method, which allows for the specification of names, descriptions, and additional schema information for arguments.

Example usage:

```python
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [("human", "Hello. Please respond in the style of {answer_style}.")]
)

# Placeholder LLM
llm = GenericFakeChatModel(messages=iter(["hello matey"]))

chain = prompt | llm | StrOutputParser()

as_tool = chain.as_tool(
    name="Style responder", description="Description of when to use tool."
)
as_tool.args
```
```
{'answer_style': {'title': 'Answer Style', 'type': 'string'}}
```

See [this guide](/docs/how_to/convert_runnable_to_tool) for more detail.
A run of hunks renumbers execution counts through the error-handling examples (counts are updated in both each cell and its recorded output, where one exists):

```diff
@@ -251,7 +428,7 @@ (code cell, id 1dad8f8e)
-"execution_count": 16,
+"execution_count": 10,
@@ -300,7 +477,7 @@ (code cell, id bb551c33)
-"execution_count": 7,
+"execution_count": 11,
@@ -351,7 +528,7 @@ (code cell, id 6615cb77-fd4c-4676-8965-f92cc71d4944)
-"execution_count": 8,
+"execution_count": 12,
@@ -383,7 +560,7 @@ (code cell, id bb2af583-eadd-41f4-a645-bf8748bd3dcd)
-"execution_count": 9,
+"execution_count": 13,
@@ -428,7 +605,7 @@ (code cell, id 4ad0932c-8610-4278-8c57-f9218f654c8a)
-"execution_count": 10,
+"execution_count": 14,
@@ -473,7 +650,7 @@ (code cell, id 7094c0e8-6192-4870-a942-aad5b5ae48fd)
-"execution_count": 11,
+"execution_count": 15,
@@ -496,7 +673,7 @@ (code cell, id b4d22022-b105-4ccc-a15b-412cb9ea3097; output: 'Error: There is no city by the name of foobar.')
-"execution_count": 12,
+"execution_count": 16,
@@ -530,7 +707,7 @@ (code cell, id 3fad1728-d367-4e1b-9b54-3172981271cf; output: "There is no such city, but it's probably above 0K there!")
-"execution_count": 13,
+"execution_count": 17,
@@ -564,7 +741,7 @@ (code cell, id ebfe7c1f-318d-4e58-99e1-f31e69473c46; output: 'The following errors occurred during tool execution: `Error: There is no city by the name of foobar.`')
-"execution_count": 14,
+"execution_count": 18,
```
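For orientation, these renumbered cells sit in the error-handling part of the guide; outputs like `'Error: There is no city by the name of foobar.'` typically come from a pattern along these lines. This is a sketch reconstructed from the visible outputs, not the notebook's exact code:

```python
from langchain_core.tools import StructuredTool, ToolException


def get_weather(city: str) -> str:
    """Get weather for the given city."""
    raise ToolException(f"Error: There is no city by the name of {city}.")


# With handle_tool_error=True the ToolException message is returned as
# the tool's result instead of propagating; a string or callable can
# also be supplied to customize the response.
get_weather_tool = StructuredTool.from_function(
    func=get_weather,
    handle_tool_error=True,
)

print(get_weather_tool.invoke({"city": "foobar"}))
# -> Error: There is no city by the name of foobar.
```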
@@ -591,13 +768,189 @@ adds the following cells after `get_weather_tool.invoke({"city": "foobar"})`:

## Returning artifacts of Tool execution

Sometimes there are artifacts of a tool's execution that we want to make accessible to downstream components in our chain or agent, but that we don't want to expose to the model itself. For example if a tool returns custom objects like Documents, we may want to pass some view or metadata about this output to the model without passing the raw output to the model. At the same time, we may want to be able to access this full output elsewhere, for example in downstream tools.

The Tool and [ToolMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolMessage.html) interfaces make it possible to distinguish between the parts of the tool output meant for the model (this is the ToolMessage.content) and those parts which are meant for use outside the model (ToolMessage.artifact).

:::info Requires ``langchain-core >= 0.2.19``

This functionality was added in ``langchain-core == 0.2.19``. Please make sure your package is up to date.

:::

If we want our tool to distinguish between message content and other artifacts, we need to specify `response_format="content_and_artifact"` when defining our tool and make sure that we return a tuple of (content, artifact):

```python
import random
from typing import List, Tuple

from langchain_core.tools import tool


@tool(response_format="content_and_artifact")
def generate_random_ints(min: int, max: int, size: int) -> Tuple[str, List[int]]:
    """Generate size random ints in the range [min, max]."""
    array = [random.randint(min, max) for _ in range(size)]
    content = f"Successfully generated array of {size} random ints in [{min}, {max}]."
    return content, array
```

If we invoke our tool directly with the tool arguments, we'll get back just the content part of the output:

```python
generate_random_ints.invoke({"min": 0, "max": 9, "size": 10})
```
```
'Successfully generated array of 10 random ints in [0, 9].'
```

If we invoke our tool with a ToolCall (like the ones generated by tool-calling models), we'll get back a ToolMessage that contains both the content and artifact generated by the Tool:

```python
generate_random_ints.invoke(
    {
        "name": "generate_random_ints",
        "args": {"min": 0, "max": 9, "size": 10},
        "id": "123",  # required
        "type": "tool_call",  # required
    }
)
```
```
ToolMessage(content='Successfully generated array of 10 random ints in [0, 9].', name='generate_random_ints', tool_call_id='123', artifact=[1, 4, 2, 5, 3, 9, 0, 4, 7, 7])
```

We can do the same when subclassing BaseTool:

```python
from langchain_core.tools import BaseTool


class GenerateRandomFloats(BaseTool):
    name: str = "generate_random_floats"
    description: str = "Generate size random floats in the range [min, max]."
    response_format: str = "content_and_artifact"

    ndigits: int = 2

    def _run(self, min: float, max: float, size: int) -> Tuple[str, List[float]]:
        range_ = max - min
        array = [
            round(min + (range_ * random.random()), ndigits=self.ndigits)
            for _ in range(size)
        ]
        content = f"Generated {size} floats in [{min}, {max}], rounded to {self.ndigits} decimals."
        return content, array

    # Optionally define an equivalent async method

    # async def _arun(self, min: float, max: float, size: int) -> Tuple[str, List[float]]:
    #     ...
```

```python
rand_gen = GenerateRandomFloats(ndigits=4)

rand_gen.invoke(
    {
        "name": "generate_random_floats",
        "args": {"min": 0.1, "max": 3.3333, "size": 3},
        "id": "123",
        "type": "tool_call",
    }
)
```
```
ToolMessage(content='Generated 3 floats in [0.1, 3.3333], rounded to 4 decimals.', name='generate_random_floats', tool_call_id='123', artifact=[1.4277, 0.7578, 2.4871])
```

The notebook metadata is also updated:

```diff
 "metadata": {
 "kernelspec": {
-"display_name": "Python 3 (ipykernel)",
+"display_name": "poetry-venv-311",
 "language": "python",
-"name": "python3"
+"name": "poetry-venv-311"
 },
@@ -609,7 +962,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.4"
+"version": "3.11.9"
 },
 "vscode": {
 "interpreter": {
```
````diff
@@ -67,15 +67,16 @@ If you'd prefer not to set an environment variable you can pass the key in directly:
 ```python
 from langchain_cohere import CohereEmbeddings

-embeddings_model = CohereEmbeddings(cohere_api_key="...")
+embeddings_model = CohereEmbeddings(cohere_api_key="...", model='embed-english-v3.0')
 ```

-Otherwise you can initialize without any params:
+Otherwise you can initialize simply as shown below:
 ```python
 from langchain_cohere import CohereEmbeddings

-embeddings_model = CohereEmbeddings()
+embeddings_model = CohereEmbeddings(model='embed-english-v3.0')
 ```
+Do note that it is mandatory to pass the model parameter while initializing the CohereEmbeddings class.

 </TabItem>
 <TabItem value="huggingface" label="Hugging Face">
````
```diff
@@ -84,7 +84,7 @@ These are the core building blocks you can use when building applications.
 - [How to: use chat model to call tools](/docs/how_to/tool_calling)
 - [How to: stream tool calls](/docs/how_to/tool_streaming)
 - [How to: few shot prompt tool behavior](/docs/how_to/tools_few_shot)
-- [How to: bind model-specific formated tools](/docs/how_to/tools_model_specific)
+- [How to: bind model-specific formatted tools](/docs/how_to/tools_model_specific)
 - [How to: force a specific tool call](/docs/how_to/tool_choice)
 - [How to: init any model in one line](/docs/how_to/chat_models_universal_init/)
```
```diff
@@ -195,7 +195,9 @@ LangChain [Tools](/docs/concepts/#tools) contain a description of the tool (to p
 - [How to: add a human in the loop to tool usage](/docs/how_to/tools_human)
 - [How to: handle errors when calling tools](/docs/how_to/tools_error)
 - [How to: disable parallel tool calling](/docs/how_to/tool_choice)
-- [How to: stream events from within a tool](/docs/how_to/tool_stream_events)
+- [How to: access the `RunnableConfig` object within a custom tool](/docs/how_to/tool_configure)
+- [How to: stream events from child runs within a custom tool](/docs/how_to/tool_stream_events)
+- [How to: return extra artifacts from a tool](/docs/how_to/tool_artifacts/)

 ### Multimodal
```
@@ -63,6 +63,38 @@ adds the following cells after "Notice that if the contents of one of the messages to merge is a list of content blocks then the merged message will have a list of content blocks. And if both messages to merge have string contents then those are concatenated with a newline character.":

The `merge_message_runs` utility also works with messages composed together using the overloaded `+` operation:

```python
messages = (
    SystemMessage("you're a good assistant.")
    + SystemMessage("you always respond with a joke.")
    + HumanMessage([{"type": "text", "text": "i wonder why it's called langchain"}])
    + HumanMessage("and who is harrison chasing anyways")
    + AIMessage(
        'Well, I guess they thought "WordRope" and "SentenceString" just didn\'t have the same ring to it!'
    )
    + AIMessage(
        "Why, he's probably chasing after the last cup of coffee in the office!"
    )
)

merged = merge_message_runs(messages)
print("\n\n".join([repr(x) for x in merged]))
```
395  docs/docs/how_to/tool_artifacts.ipynb (new file)

@@ -0,0 +1,395 @@ adds the full notebook, rendered below:
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "503e36ae-ca62-4f8a-880c-4fe78ff5df93",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# How to return extra artifacts from a tool\n",
|
||||
"\n",
|
||||
":::info Prerequisites\n",
|
||||
"This guide assumes familiarity with the following concepts:\n",
|
||||
"\n",
|
||||
"- [Tools](/docs/concepts/#tools)\n",
|
||||
"- [Function/tool calling](/docs/concepts/#functiontool-calling)\n",
|
||||
"\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
"Tools are utilities that can be called by a model, and whose outputs are designed to be fed back to a model. Sometimes, however, there are artifacts of a tool's execution that we want to make accessible to downstream components in our chain or agent, but that we don't want to expose to the model itself. For example if a tool returns a custom object, a dataframe or an image, we may want to pass some metadata about this output to the model without passing the actual output to the model. At the same time, we may want to be able to access this full output elsewhere, for example in downstream tools.\n",
|
||||
"\n",
|
||||
"The Tool and [ToolMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolMessage.html) interfaces make it possible to distinguish between the parts of the tool output meant for the model (this is the ToolMessage.content) and those parts which are meant for use outside the model (ToolMessage.artifact).\n",
|
||||
"\n",
|
||||
":::info Requires ``langchain-core >= 0.2.19``\n",
|
||||
"\n",
|
||||
"This functionality was added in ``langchain-core == 0.2.19``. Please make sure your package is up to date.\n",
|
||||
"\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
"## Defining the tool\n",
|
||||
"\n",
|
||||
"If we want our tool to distinguish between message content and other artifacts, we need to specify `response_format=\"content_and_artifact\"` when defining our tool and make sure that we return a tuple of (content, artifact):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "762b9199-885f-4946-9c98-cc54d72b0d76",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -qU \"langchain-core>=0.2.19\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b9eb179d-1f41-4748-9866-b3d3e8c73cd0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"from typing import List, Tuple\n",
|
||||
"\n",
|
||||
"from langchain_core.tools import tool\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool(response_format=\"content_and_artifact\")\n",
|
||||
"def generate_random_ints(min: int, max: int, size: int) -> Tuple[str, List[int]]:\n",
|
||||
" \"\"\"Generate size random ints in the range [min, max].\"\"\"\n",
|
||||
" array = [random.randint(min, max) for _ in range(size)]\n",
|
||||
" content = f\"Successfully generated array of {size} random ints in [{min}, {max}].\"\n",
|
||||
" return content, array"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0ab05d25-af4a-4e5a-afe2-f090416d7ee7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Invoking the tool with ToolCall\n",
|
||||
"\n",
|
||||
"If we directly invoke our tool with just the tool arguments, you'll notice that we only get back the content part of the Tool output:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "5e7d5e77-3102-4a59-8ade-e4e699dd1817",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Successfully generated array of 10 random ints in [0, 9].'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Failed to batch ingest runs: LangSmithRateLimitError('Rate limit exceeded for https://api.smith.langchain.com/runs/batch. HTTPError(\\'429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/batch\\', \\'{\"detail\":\"Monthly unique traces usage limit exceeded\"}\\')')\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"generate_random_ints.invoke({\"min\": 0, \"max\": 9, \"size\": 10})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "30db7228-f04c-489e-afda-9a572eaa90a1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In order to get back both the content and the artifact, we need to invoke our model with a ToolCall (which is just a dictionary with \"name\", \"args\", \"id\" and \"type\" keys), which has additional info needed to generate a ToolMessage like the tool call ID:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "da1d939d-a900-4b01-92aa-d19011a6b034",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ToolMessage(content='Successfully generated array of 10 random ints in [0, 9].', name='generate_random_ints', tool_call_id='123', artifact=[2, 8, 0, 6, 0, 0, 1, 5, 0, 0])"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"generate_random_ints.invoke(\n",
|
||||
" {\n",
|
||||
" \"name\": \"generate_random_ints\",\n",
|
||||
" \"args\": {\"min\": 0, \"max\": 9, \"size\": 10},\n",
|
||||
" \"id\": \"123\", # required\n",
|
||||
" \"type\": \"tool_call\", # required\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a3cfc03d-020b-42c7-b0f8-c824af19e45e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using with a model\n",
|
||||
"\n",
|
||||
"With a [tool-calling model](/docs/how_to/tool_calling/), we can easily use a model to call our Tool and generate ToolMessages:\n",
|
||||
"\n",
|
||||
"```{=mdx}\n",
|
||||
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
|
||||
"\n",
|
||||
"<ChatModelTabs\n",
|
||||
" customVarName=\"llm\"\n",
|
||||
"/>\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "74de0286-b003-4b48-9cdd-ecab435515ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# | echo: false\n",
|
||||
"# | output: false\n",
|
||||
"\n",
|
||||
"from langchain_anthropic import ChatAnthropic\n",
|
||||
"\n",
|
||||
"llm = ChatAnthropic(model=\"claude-3-5-sonnet-20240620\", temperature=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "8a67424b-d19c-43df-ac7b-690bca42146c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'generate_random_ints',\n",
|
||||
" 'args': {'min': 1, 'max': 24, 'size': 6},\n",
|
||||
" 'id': 'toolu_01EtALY3Wz1DVYhv1TLvZGvE',\n",
|
||||
" 'type': 'tool_call'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm_with_tools = llm.bind_tools([generate_random_ints])\n",
|
||||
"\n",
|
||||
"ai_msg = llm_with_tools.invoke(\"generate 6 positive ints less than 25\")\n",
|
||||
"ai_msg.tool_calls"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "00c4e906-3ca8-41e8-a0be-65cb0db7d574",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ToolMessage(content='Successfully generated array of 6 random ints in [1, 24].', name='generate_random_ints', tool_call_id='toolu_01EtALY3Wz1DVYhv1TLvZGvE', artifact=[2, 20, 23, 8, 1, 15])"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"generate_random_ints.invoke(ai_msg.tool_calls[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ddef2690-70de-4542-ab20-2337f77f3e46",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we just pass in the tool call args, we'll only get back the content:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "f4a6c9a6-0ffc-4b0e-a59f-f3c3d69d824d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Successfully generated array of 6 random ints in [1, 24].'"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"generate_random_ints.invoke(ai_msg.tool_calls[0][\"args\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "98d6443b-ff41-4d91-8523-b6274fc74ee5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we wanted to declaratively create a chain, we could do this:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "eb55ec23-95a4-464e-b886-d9679bf3aaa2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[ToolMessage(content='Successfully generated array of 1 random ints in [1, 5].', name='generate_random_ints', tool_call_id='toolu_01FwYhnkwDPJPbKdGq4ng6uD', artifact=[5])]"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from operator import attrgetter\n",
|
||||
"\n",
|
||||
"chain = llm_with_tools | attrgetter(\"tool_calls\") | generate_random_ints.map()\n",
|
||||
"\n",
|
||||
"chain.invoke(\"give me a random number between 1 and 5\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4df46be2-babb-4bfe-a641-91cd3d03ffaf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Creating from BaseTool class\n",
|
||||
"\n",
|
||||
"If you want to create a BaseTool object directly, instead of decorating a function with `@tool`, you can do so like this:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "9a9129e1-6aee-4a10-ad57-62ef3bf0276c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.tools import BaseTool\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GenerateRandomFloats(BaseTool):\n",
|
||||
" name: str = \"generate_random_floats\"\n",
|
||||
" description: str = \"Generate size random floats in the range [min, max].\"\n",
|
||||
" response_format: str = \"content_and_artifact\"\n",
|
||||
"\n",
|
||||
" ndigits: int = 2\n",
|
||||
"\n",
|
||||
" def _run(self, min: float, max: float, size: int) -> Tuple[str, List[float]]:\n",
|
||||
" range_ = max - min\n",
|
||||
" array = [\n",
|
||||
" round(min + (range_ * random.random()), ndigits=self.ndigits)\n",
|
||||
" for _ in range(size)\n",
|
||||
" ]\n",
|
||||
" content = f\"Generated {size} floats in [{min}, {max}], rounded to {self.ndigits} decimals.\"\n",
|
||||
" return content, array\n",
|
||||
"\n",
|
||||
" # Optionally define an equivalent async method\n",
|
||||
"\n",
|
||||
" # async def _arun(self, min: float, max: float, size: int) -> Tuple[str, List[float]]:\n",
|
||||
" # ..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "d7322619-f420-4b29-8ee5-023e693d0179",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Generated 3 floats in [0.1, 3.3333], rounded to 4 decimals.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"rand_gen = GenerateRandomFloats(ndigits=4)\n",
|
||||
"rand_gen.invoke({\"min\": 0.1, \"max\": 3.3333, \"size\": 3})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "0892f277-23a6-4bb8-a0e9-59f533ac9750",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ToolMessage(content='Generated 3 floats in [0.1, 3.3333], rounded to 4 decimals.', name='generate_random_floats', tool_call_id='123', artifact=[1.5789, 2.464, 2.2719])"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"rand_gen.invoke(\n",
|
||||
" {\n",
|
||||
" \"name\": \"generate_random_floats\",\n",
|
||||
" \"args\": {\"min\": 0.1, \"max\": 3.3333, \"size\": 3},\n",
|
||||
" \"id\": \"123\",\n",
|
||||
" \"type\": \"tool_call\",\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "poetry-venv-311",
|
||||
"language": "python",
|
||||
"name": "poetry-venv-311"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
132  docs/docs/how_to/tool_configure.ipynb (new file)

@@ -0,0 +1,132 @@ adds the full notebook, rendered below:
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# How to access the RunnableConfig object within a custom tool\n",
|
||||
"\n",
|
||||
":::info Prerequisites\n",
|
||||
"\n",
|
||||
"This guide assumes familiarity with the following concepts:\n",
|
||||
"\n",
|
||||
"- [LangChain Tools](/docs/concepts/#tools)\n",
|
||||
"- [Custom tools](/docs/how_to/custom_tools)\n",
|
||||
"- [LangChain Expression Language (LCEL)](/docs/concepts/#langchain-expression-language-lcel)\n",
|
||||
"- [Configuring runnable behavior](/docs/how_to/configure/)\n",
|
||||
"\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
"If you have a tool that call chat models, retrievers, or other runnables, you may want to access internal events from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
|
||||
"\n",
|
||||
"Tools are runnables, and you can treat them the same way as any other runnable at the interface level - you can call `invoke()`, `batch()`, and `stream()` on them as normal. However, when writing custom tools, you may want to invoke other runnables like chat models or retrievers. In order to properly trace and configure those sub-invocations, you'll need to manually access and pass in the tool's current [`RunnableConfig`](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.config.RunnableConfig.html) object. This guide show you some examples of how to do that.\n",
"\n",
":::caution Compatibility\n",
"\n",
"This guide requires `langchain-core>=0.2.16`.\n",
"\n",
":::\n",
"\n",
"## Inferring by parameter type\n",
"\n",
"To access the active config object from your custom tool, you'll need to add a parameter to your tool's signature typed as `RunnableConfig`. When you invoke your tool, LangChain will inspect your tool's signature, look for a parameter typed as `RunnableConfig`, and if it exists, populate that parameter with the correct value.\n",
"\n",
"**Note:** The actual name of the parameter doesn't matter, only the typing.\n",
"\n",
"To illustrate this, define a custom tool that takes two parameters - one typed as a string, the other typed as `RunnableConfig`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain_core"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.runnables import RunnableConfig\n",
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"async def reverse_tool(text: str, special_config_param: RunnableConfig) -> str:\n",
"    \"\"\"A test tool that combines input text with a configurable parameter.\"\"\"\n",
"    return (text + special_config_param[\"configurable\"][\"additional_field\"])[::-1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, if we invoke the tool with a `config` containing a `configurable` field, we can see that `additional_field` is passed through correctly:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'321cba'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await reverse_tool.ainvoke(\n",
"    {\"text\": \"abc\"}, config={\"configurable\": {\"additional_field\": \"123\"}}\n",
")"
]
},
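As a side note, because `special_config_param` is typed as `RunnableConfig`, it is excluded from the input schema that a model or caller sees - only `text` remains. A minimal sketch to check this, assuming the `reverse_tool` defined above (`.args` is the standard `BaseTool` accessor for the argument schema; the printed dict is the roughly expected shape, not verified output):

```python
# The RunnableConfig-typed parameter is hidden from the tool's input schema,
# so callers only need to supply `text`.
print(reverse_tool.args)
# -> {'text': {'title': 'Text', 'type': 'string'}}  (approximate shape)
```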
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"You've now seen how to configure and stream events from within a tool. Next, check out the following guides for more on using tools:\n",
"\n",
"- [Stream events from child runs within a custom tool](/docs/how_to/tool_stream_events/)\n",
"- Pass [tool results back to a model](/docs/how_to/tool_results_pass_to_model)\n",
"\n",
"You can also check out some more specific uses of tool calling:\n",
"\n",
"- Building [tool-using chains and agents](/docs/how_to#tools)\n",
"- Getting [structured outputs](/docs/how_to/structured_output/) from models"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@@ -6,12 +6,20 @@
"source": [
"# How to pass tool outputs to the model\n",
"\n",
"If we're using the model-generated tool invocations to actually call tools and want to pass the tool results back to the model, we can do so using `ToolMessage`s. First, let's define our tools and our model."
":::info Prerequisites\n",
"This guide assumes familiarity with the following concepts:\n",
"\n",
"- [Tools](/docs/concepts/#tools)\n",
"- [Function/tool calling](/docs/concepts/#functiontool-calling)\n",
"\n",
":::\n",
"\n",
"If we're using the model-generated tool invocations to actually call tools and want to pass the tool results back to the model, we can do so using `ToolMessage`s and `ToolCall`s. First, let's define our tools and our model."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -35,7 +43,7 @@
},
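The two code cells above are elided by the diff hunks. Judging from the tool calls and model name that appear in the outputs further down (`multiply` and `add`, each taking integers `a` and `b`, answered by `gpt-3.5-turbo-0125`), the elided definitions plausibly look like this sketch - the exact upstream cell contents are an assumption:

```python
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI


@tool
def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b


@tool
def multiply(a: int, b: int) -> int:
    """Multiplies a and b."""
    return a * b


tools = [add, multiply]

# The outputs below report gpt-3.5-turbo-0125, so the model is presumably
# bound to the tools like this:
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
llm_with_tools = llm.bind_tools(tools)
```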
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -54,25 +62,32 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use ``ToolMessage`` to pass back the output of the tool calls to the model."
"The nice thing about Tools is that if we invoke them with a ToolCall, we'll automatically get back a ToolMessage that can be fed back to the model: \n",
"\n",
":::info Requires ``langchain-core >= 0.2.19``\n",
"\n",
"This functionality was added in ``langchain-core == 0.2.19``. Please make sure your package is up to date.\n",
"\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='What is 3 * 12? Also, what is 11 + 49?'),\n",
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_svc2GLSxNFALbaCAbSjMI9J8', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'Multiply'}, 'type': 'function'}, {'id': 'call_r8jxte3zW6h3MEGV3zH2qzFh', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'Add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 50, 'prompt_tokens': 105, 'total_tokens': 155}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_d9767fc5b9', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-a79ad1dd-95f1-4a46-b688-4c83f327a7b3-0', tool_calls=[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_svc2GLSxNFALbaCAbSjMI9J8'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_r8jxte3zW6h3MEGV3zH2qzFh'}]),\n",
" ToolMessage(content='36', tool_call_id='call_svc2GLSxNFALbaCAbSjMI9J8'),\n",
" ToolMessage(content='60', tool_call_id='call_r8jxte3zW6h3MEGV3zH2qzFh')]"
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Smg3NHJNxrKfAmd4f9GkaYn3', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'multiply'}, 'type': 'function'}, {'id': 'call_55K1C0DmH6U5qh810gW34xZ0', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 49, 'prompt_tokens': 88, 'total_tokens': 137}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-56657feb-96dd-456c-ab8e-1857eab2ade0-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_Smg3NHJNxrKfAmd4f9GkaYn3', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_55K1C0DmH6U5qh810gW34xZ0', 'type': 'tool_call'}], usage_metadata={'input_tokens': 88, 'output_tokens': 49, 'total_tokens': 137}),\n",
" ToolMessage(content='36', name='multiply', tool_call_id='call_Smg3NHJNxrKfAmd4f9GkaYn3'),\n",
" ToolMessage(content='60', name='add', tool_call_id='call_55K1C0DmH6U5qh810gW34xZ0')]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
@@ -85,24 +100,25 @@
"messages.append(ai_msg)\n",
"for tool_call in ai_msg.tool_calls:\n",
"    selected_tool = {\"add\": add, \"multiply\": multiply}[tool_call[\"name\"].lower()]\n",
"    tool_output = selected_tool.invoke(tool_call[\"args\"])\n",
"    messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
"    tool_msg = selected_tool.invoke(tool_call)\n",
"    messages.append(tool_msg)\n",
"messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='3 * 12 is 36 and 11 + 49 is 60.', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 171, 'total_tokens': 189}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_d9767fc5b9', 'finish_reason': 'stop', 'logprobs': None}, id='run-20b52149-e00d-48ea-97cf-f8de7a255f8c-0')"
"AIMessage(content='3 * 12 is 36 and 11 + 49 is 60.', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 153, 'total_tokens': 171}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-ba5032f0-f773-406d-a408-8314e66511d0-0', usage_metadata={'input_tokens': 153, 'output_tokens': 18, 'total_tokens': 171})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
@@ -118,10 +134,24 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-311",
"language": "python",
"name": "poetry-venv-311"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
@@ -28,17 +28,21 @@
"which shows how to create an agent that keeps track of a given user's favorite pets.\n",
":::\n",
"\n",
"There are times where tools need to use runtime values that should not be populated by the LLM. For example, the tool logic may require using the ID of the user who made the request. In this case, allowing the LLM to control the parameter is a security risk.\n",
"You may need to bind values to a tool that are only known at runtime. For example, the tool logic may require using the ID of the user who made the request.\n",
"\n",
"Instead, the LLM should only control the parameters of the tool that are meant to be controlled by the LLM, while other parameters (such as user ID) should be fixed by the application logic. These defined parameters should not be part of the tool's final schema.\n",
"Most of the time, such values should not be controlled by the LLM. In fact, allowing the LLM to control the user ID may lead to a security risk.\n",
"\n",
"This how-to guide shows some design patterns that create the tool dynamically at run time and bind appropriate values to them.\n",
"Instead, the LLM should only control the parameters of the tool that are meant to be controlled by the LLM, while other parameters (such as user ID) should be fixed by the application logic.\n",
"\n",
"This how-to guide shows a simple design pattern that creates the tool dynamically at run time and binds the appropriate values to it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can bind them to chat models as follows:\n",
"\n",
"```{=mdx}\n",
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
"\n",
@@ -51,14 +55,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# | output: false\n",
"# | echo: false\n",
"\n",
"%pip install -qU langchain_core langchain_openai\n",
"%pip install -qU langchain langchain_openai\n",
"\n",
"import os\n",
"from getpass import getpass\n",
@@ -75,17 +90,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the `curry` utility function\n",
"# Passing request time information\n",
"\n",
":::caution Compatibility\n",
"\n",
"This function is only available in `langchain_core>=0.2.17`.\n",
"\n",
":::\n",
"\n",
"We can bind arguments to the tool's inner function via a utility wrapper. This will use a technique called [currying](https://en.wikipedia.org/wiki/Currying) to bind arguments to the function while also removing them from the function signature.\n",
"\n",
"Below, we initialize a tool that lists a user's favorite pets. It requires a `user_id` that we'll curry ahead of time."
"The idea is to create the tool dynamically at request time, and bind to it the appropriate information. For example,\n",
"this information may be the user ID as resolved from the request itself."
]
},
{
@@ -94,98 +102,18 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.tools import StructuredTool\n",
"from langchain_core.utils.curry import curry\n",
"from typing import List\n",
"\n",
"user_to_pets = {\"eugene\": [\"cats\"]}\n",
"\n",
"\n",
"def list_favorite_pets(user_id: str) -> None:\n",
"    \"\"\"List favorite pets, if any.\"\"\"\n",
"    return user_to_pets.get(user_id, [])\n",
"\n",
"\n",
"curried_function = curry(list_favorite_pets, user_id=\"eugene\")\n",
"\n",
"curried_tool = StructuredTool.from_function(curried_function)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we examine the schema of the curried tool, we can see that it no longer has `user_id` as part of its signature:"
"from langchain_core.output_parsers import JsonOutputParser\n",
"from langchain_core.tools import BaseTool, tool"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'list_favorite_petsSchema',\n",
" 'description': 'List favorite pets, if any.',\n",
" 'type': 'object',\n",
" 'properties': {}}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"curried_tool.input_schema.schema()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But if we invoke it, we can see that it returns Eugene's favorite pets, `cats`:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['cats']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"curried_tool.invoke({})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using scope\n",
"\n",
"We can achieve a similar result by wrapping the tool declarations themselves in a function. This lets us take advantage of the closure created by the wrapper to pass a variable into each tool. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"from langchain_core.tools import BaseTool, tool\n",
"\n",
"user_to_pets = {}\n",
"\n",
"\n",
@@ -205,7 +133,7 @@
"\n",
"    @tool\n",
"    def list_favorite_pets() -> None:\n",
"        \"\"\"List favorite pets, if any.\"\"\"\n",
"        \"\"\"List favorite pets if any.\"\"\"\n",
"        return user_to_pets.get(user_id, [])\n",
"\n",
"    return [update_favorite_pets, delete_favorite_pets, list_favorite_pets]"
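Only the tail of `generate_tools_for_user` survives the hunk above. Reconstructed from the visible fragments and the `update_favorite_pets` tool call shown in the output below, the closure-based factory plausibly reads as follows - the bodies of `update_favorite_pets` and `delete_favorite_pets` are assumptions:

```python
def generate_tools_for_user(user_id: str) -> List[BaseTool]:
    """Generate a set of tools scoped to the given user id via closures."""

    @tool
    def update_favorite_pets(pets: List[str]) -> None:
        """Add the list of favorite pets."""
        user_to_pets[user_id] = pets

    @tool
    def delete_favorite_pets() -> None:
        """Delete the list of favorite pets if any."""
        user_to_pets.pop(user_id, None)

    @tool
    def list_favorite_pets() -> None:
        """List favorite pets if any."""
        return user_to_pets.get(user_id, [])

    return [update_favorite_pets, delete_favorite_pets, list_favorite_pets]
```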
@@ -215,12 +143,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Verify that the tools work correctly:"
"Verify that the tools work correctly"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -241,14 +169,21 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"\n",
"def handle_run_time_request(user_id: str, query: str):\n",
"    \"\"\"Handle run time request.\"\"\"\n",
"    tools = generate_tools_for_user(user_id)\n",
"    llm_with_tools = llm.bind_tools(tools)\n",
"    prompt = ChatPromptTemplate.from_messages(\n",
"        [(\"system\", \"You are a helpful assistant.\")],\n",
"    )\n",
"    chain = prompt | llm_with_tools\n",
"    return llm_with_tools.invoke(query)"
]
},
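Note that as written, the `chain = prompt | llm_with_tools` built inside `handle_run_time_request` is never used; the function returns `llm_with_tools.invoke(query)` directly. Calling the handler yields an `AIMessage` whose `tool_calls` carry the model-chosen arguments, as the output below shows. A hedged usage sketch (the query string is an assumption consistent with that output):

```python
# The user ID is resolved by the application, not by the model.
ai_message = handle_run_time_request(
    "eugene", "my favorite animals are cats and parrots"
)
# Inspect which tool the model asked to run and with what arguments.
ai_message.tool_calls
```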
@@ -261,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -269,10 +204,10 @@
"text/plain": [
"[{'name': 'update_favorite_pets',\n",
"  'args': {'pets': ['cats', 'parrots']},\n",
"  'id': 'call_c8agYHY1COFSAgwZR11OGCmQ'}]"
"  'id': 'call_jJvjPXsNbFO5MMgW0q84iqCN'}]"
]
},
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -313,7 +248,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
"version": "3.11.4"
}
},
"nbformat": 4,
@@ -4,25 +4,31 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to stream events from within a tool\n",
"# How to stream events from child runs within a custom tool\n",
"\n",
":::info Prerequisites\n",
"\n",
"This guide assumes familiarity with the following concepts:\n",
"- [LangChain Tools](/docs/concepts/#tools)\n",
"- [Custom tools](/docs/how_to/custom_tools)\n",
"- [Using stream events](/docs/how_to/streaming/#using-stream-events)\n",
"- [Accessing RunnableConfig within a custom tool](/docs/how_to/tool_configure/)\n",
"\n",
":::\n",
"\n",
"If you have tools that call LLMs, retrievers, or other runnables, you may want to access internal events from those runnables. This guide shows you a few ways you can do this using the `astream_events()` method.\n",
"If you have tools that call chat models, retrievers, or other runnables, you may want to access internal events from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
"\n",
":::caution Compatibility\n",
"\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"\n",
"If you are running `python>=3.11`, configuration will automatically propagate to child runnables in async environments, and you don't need to access the `RunnableConfig` object for that tool as shown in this guide. However, it is still a good idea if your code may run in other Python versions.\n",
"\n",
"This guide also requires `langchain-core>=0.2.16`.\n",
"\n",
":::caution\n",
"LangChain cannot automatically propagate callbacks to child runnables if you are running async code in python<=3.10.\n",
" \n",
"This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
":::\n",
"\n",
"We'll define a custom tool below that calls a chain that summarizes its input in a special way by prompting an LLM to return only 10 words, then reversing the output:\n",
"Say you have a custom tool that calls a chain that condenses its input by prompting a chat model to return only 10 words, then reversing the output. First, define it in a naive way:\n",
"\n",
"```{=mdx}\n",
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
@@ -40,7 +46,7 @@
"# | output: false\n",
"# | echo: false\n",
"\n",
"%pip install -qU langchain langchain_anthropic\n",
"%pip install -qU langchain langchain_anthropic langchain_core\n",
"\n",
"import os\n",
"from getpass import getpass\n",
@@ -65,7 +71,7 @@
"\n",
"\n",
"@tool\n",
"def special_summarization_tool(long_text: str) -> str:\n",
"async def special_summarization_tool(long_text: str) -> str:\n",
"    \"\"\"A tool that summarizes input text using advanced techniques.\"\"\"\n",
"    prompt = ChatPromptTemplate.from_template(\n",
"        \"You are an expert writer. Summarize the following text in 10 words or less:\\n\\n{long_text}\"\n",
@@ -75,7 +81,7 @@
"        return x[::-1]\n",
"\n",
"    chain = prompt | model | StrOutputParser() | reverse\n",
"    summary = chain.invoke({\"long_text\": long_text})\n",
"    summary = await chain.ainvoke({\"long_text\": long_text})\n",
"    return summary"
]
},
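Assembled from the diff fragments above, the updated async tool reads roughly as follows. The imports and the `model` definition fall outside the hunks and are assumptions (the event metadata later in this notebook reports `claude-3-5-sonnet-20240620`):

```python
from langchain_anthropic import ChatAnthropic
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool

# Assumed model, matching the ls_model_name in the streamed events below.
model = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0)


@tool
async def special_summarization_tool(long_text: str) -> str:
    """A tool that summarizes input text using advanced techniques."""
    prompt = ChatPromptTemplate.from_template(
        "You are an expert writer. Summarize the following text in 10 words or less:\n\n{long_text}"
    )

    def reverse(x: str):
        # Reverse the summary string as the final chain step.
        return x[::-1]

    chain = prompt | model | StrOutputParser() | reverse
    summary = await chain.ainvoke({"long_text": long_text})
    return summary
```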
@@ -83,7 +89,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If you just invoke the tool directly, you can see that you only get the final response:"
"Invoking the tool directly works just fine:"
]
},
{
@@ -116,31 +122,90 @@
"Coming! Hang on a second.\n",
"\"\"\"\n",
"\n",
"special_summarization_tool.invoke({\"long_text\": LONG_TEXT})"
"await special_summarization_tool.ainvoke({\"long_text\": LONG_TEXT})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you wanted to access the raw output from the chat model, you could use the [`astream_events()`](/docs/how_to/streaming/#using-stream-events) method and look for `on_chat_model_end` events:"
"But if you wanted to access the raw output from the chat model rather than the full tool, you might try to use the [`astream_events()`](/docs/how_to/streaming/#using-stream-events) method and look for an `on_chat_model_end` event. Here's what happens:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"stream = special_summarization_tool.astream_events(\n",
"    {\"long_text\": LONG_TEXT}, version=\"v2\"\n",
")\n",
"\n",
"async for event in stream:\n",
"    if event[\"event\"] == \"on_chat_model_end\":\n",
"        # Never triggers in python<=3.10!\n",
"        print(event)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You'll notice (unless you're running through this guide in `python>=3.11`) that there are no chat model events emitted from the child run!\n",
"\n",
"This is because the example above does not pass the tool's config object into the internal chain. To fix this, redefine your tool to take a special parameter typed as `RunnableConfig` (see [this guide](/docs/how_to/tool_configure) for more details). You'll also need to pass that parameter through into the internal chain when executing it:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.runnables import RunnableConfig\n",
"\n",
"\n",
"@tool\n",
"async def special_summarization_tool_with_config(\n",
"    long_text: str, config: RunnableConfig\n",
") -> str:\n",
"    \"\"\"A tool that summarizes input text using advanced techniques.\"\"\"\n",
"    prompt = ChatPromptTemplate.from_template(\n",
"        \"You are an expert writer. Summarize the following text in 10 words or less:\\n\\n{long_text}\"\n",
"    )\n",
"\n",
"    def reverse(x: str):\n",
"        return x[::-1]\n",
"\n",
"    chain = prompt | model | StrOutputParser() | reverse\n",
"    # Pass the \"config\" object as an argument to any executed runnables\n",
"    summary = await chain.ainvoke({\"long_text\": long_text}, config=config)\n",
"    return summary"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now try the same `astream_events()` call as before with your new tool:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chat_model_end', 'data': {'output': AIMessage(content='Bee defies physics; Barry chooses outfit for graduation day.', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-195c0986-2ffa-43a3-9366-f2f96c42fe57', usage_metadata={'input_tokens': 182, 'output_tokens': 16, 'total_tokens': 198}), 'input': {'messages': [[HumanMessage(content=\"You are an expert writer. Summarize the following text in 10 words or less:\\n\\n\\nNARRATOR:\\n(Black screen with text; The sound of buzzing bees can be heard)\\nAccording to all known laws of aviation, there is no way a bee should be able to fly. Its wings are too small to get its fat little body off the ground. The bee, of course, flies anyway because bees don't care what humans think is impossible.\\nBARRY BENSON:\\n(Barry is picking out a shirt)\\nYellow, black. Yellow, black. Yellow, black. Yellow, black. Ooh, black and yellow! Let's shake it up a little.\\nJANET BENSON:\\nBarry! Breakfast is ready!\\nBARRY:\\nComing! Hang on a second.\\n\")]]}}, 'run_id': '195c0986-2ffa-43a3-9366-f2f96c42fe57', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['370919df-1bc3-43ae-aab2-8e112a4ddf47', 'de535624-278b-4927-9393-6d0cac3248df']}\n"
"{'event': 'on_chat_model_end', 'data': {'output': AIMessage(content='Bee defies physics; Barry chooses outfit for graduation day.', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-d23abc80-0dce-4f74-9d7b-fb98ca4f2a9e', usage_metadata={'input_tokens': 182, 'output_tokens': 16, 'total_tokens': 198}), 'input': {'messages': [[HumanMessage(content=\"You are an expert writer. Summarize the following text in 10 words or less:\\n\\n\\nNARRATOR:\\n(Black screen with text; The sound of buzzing bees can be heard)\\nAccording to all known laws of aviation, there is no way a bee should be able to fly. Its wings are too small to get its fat little body off the ground. The bee, of course, flies anyway because bees don't care what humans think is impossible.\\nBARRY BENSON:\\n(Barry is picking out a shirt)\\nYellow, black. Yellow, black. Yellow, black. Yellow, black. Ooh, black and yellow! Let's shake it up a little.\\nJANET BENSON:\\nBarry! Breakfast is ready!\\nBARRY:\\nComing! Hang on a second.\\n\")]]}}, 'run_id': 'd23abc80-0dce-4f74-9d7b-fb98ca4f2a9e', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['f25c41fe-8972-4893-bc40-cecf3922c1fa']}\n"
]
}
],
"source": [
"stream = special_summarization_tool.astream_events(\n",
"stream = special_summarization_tool_with_config.astream_events(\n",
"    {\"long_text\": LONG_TEXT}, version=\"v2\"\n",
")\n",
"\n",
@@ -153,38 +218,38 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"And you can see that you get the raw response from the chat model.\n",
"Awesome! This time there's an event emitted.\n",
"\n",
"`astream_events()` will automatically call internal runnables in a chain with streaming enabled if possible, so if you wanted a stream of tokens as they are generated from the chat model, you could simply filter your calls to look for `on_chat_model_stream` events with no other changes:"
"For streaming, `astream_events()` automatically calls internal runnables in a chain with streaming enabled if possible, so if you wanted a stream of tokens as they are generated from the chat model, you could simply filter to look for `on_chat_model_stream` events with no other changes:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', usage_metadata={'input_tokens': 182, 'output_tokens': 0, 'total_tokens': 182})}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='Bee', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' def', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='ies physics', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=';', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' Barry', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' cho', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='oses outfit', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' for', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' graduation', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' day', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='.', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', usage_metadata={'input_tokens': 0, 'output_tokens': 16, 'total_tokens': 16})}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n"
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42', usage_metadata={'input_tokens': 182, 'output_tokens': 0, 'total_tokens': 182})}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='Bee', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' def', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='ies physics', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=';', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' Barry', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' cho', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='oses outfit', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' for', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' graduation', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' day', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='.', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42', usage_metadata={'input_tokens': 0, 'output_tokens': 16, 'total_tokens': 16})}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n"
]
}
],
"source": [
"stream = special_summarization_tool.astream_events(\n",
"stream = special_summarization_tool_with_config.astream_events(\n",
"    {\"long_text\": LONG_TEXT}, version=\"v2\"\n",
")\n",
"\n",
@@ -193,65 +258,14 @@
"        print(event)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that you still have access to the final tool response as well. You can access it by looking for an `on_tool_end` event.\n",
"\n",
"To make events your tool emits easier to identify, you can also add identifiers to runnables using the `with_config()` method. `run_name` will apply only to the runnable you attach it to, while `tags` will be inherited by runnables called within your initial runnable.\n",
"\n",
"Let's redeclare the tool with a tag, then run it with `astream_events()` with some filters. You should only see streamed events from the chat model and the final tool output:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630', usage_metadata={'input_tokens': 182, 'output_tokens': 0, 'total_tokens': 182})}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='Bee', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' def', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='ies physics', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=';', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' Barry', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' cho', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='oses outfit', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' for', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' graduation', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' day', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='.', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630', usage_metadata={'input_tokens': 0, 'output_tokens': 16, 'total_tokens': 16})}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_tool_end', 'data': {'output': '.yad noitaudarg rof tiftuo sesoohc yrraB ;scisyhp seifed eeB'}, 'run_id': '49d9d7d3-2b02-4964-a6c5-12f57a063146', 'name': 'special_summarization_tool', 'tags': ['bee_movie'], 'metadata': {}, 'parent_ids': []}\n"
]
}
],
"source": [
"tagged_tool = special_summarization_tool.with_config({\"tags\": [\"bee_movie\"]})\n",
"\n",
"stream = tagged_tool.astream_events(\n",
"    {\"long_text\": LONG_TEXT}, version=\"v2\", include_tags=[\"bee_movie\"]\n",
")\n",
"\n",
"async for event in stream:\n",
"    event_type = event[\"event\"]\n",
"    if event_type == \"on_chat_model_stream\" or event_type == \"on_tool_end\":\n",
"        print(event)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"Now you've learned how to stream events from within a tool. Next, you can learn more about how to use tools:\n",
"You've now seen how to stream events from within a tool. Next, check out the following guides for more on using tools:\n",
"\n",
"- Bind [model-specific tools](/docs/how_to/tools_model_specific/)\n",
"- Pass [runtime values to tools](/docs/how_to/tool_runtime)\n",
"- Pass [tool results back to a model](/docs/how_to/tool_results_pass_to_model)\n",
"\n",
@@ -419,13 +419,13 @@
"Invoking: `exponentiate` with `{'base': 405, 'exponent': 2}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3m164025\u001b[0m\u001b[32;1m\u001b[1;3mThe result of taking 3 to the fifth power is 243. \n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3m13286025\u001b[0m\u001b[32;1m\u001b[1;3mThe result of taking 3 to the fifth power is 243. \n",
"\n",
"The sum of twelve and three is 15. \n",
"\n",
"Multiplying 243 by 15 gives 3645. \n",
"\n",
"Finally, squaring 3645 gives 164025.\u001b[0m\n",
"Finally, squaring 3645 gives 13286025.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -434,7 +434,7 @@
"data": {
"text/plain": [
"{'input': 'Take 3 to the fifth power and multiply that by the sum of twelve and three, then square the whole result',\n",
" 'output': 'The result of taking 3 to the fifth power is 243. \\n\\nThe sum of twelve and three is 15. \\n\\nMultiplying 243 by 15 gives 3645. \\n\\nFinally, squaring 3645 gives 164025.'}"
" 'output': 'The result of taking 3 to the fifth power is 243. \\n\\nThe sum of twelve and three is 15. \\n\\nMultiplying 243 by 15 gives 3645. \\n\\nFinally, squaring 3645 gives 13286025.'}"
]
},
"execution_count": 18,
@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 1,
"id": "10ad9224",
"metadata": {
"ExecuteTime": {
@@ -1809,7 +1809,6 @@
"cell_type": "markdown",
"id": "0c69d84d",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
@@ -1891,7 +1890,6 @@
"cell_type": "markdown",
"id": "5da41b77",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
@@ -2149,6 +2147,7 @@
},
{
"cell_type": "markdown",
"id": "2ac1a8c7",
"metadata": {},
"source": [
"## SingleStoreDB Semantic Cache\n",
@@ -2173,6 +2172,353 @@
")"
]
},
{
"cell_type": "markdown",
"id": "7019c991-0101-4f9c-b212-5729a5471293",
"metadata": {},
"source": [
"## Couchbase Caches\n",
"\n",
"Use [Couchbase](https://couchbase.com/) as a cache for prompts and responses."
]
},
{
"cell_type": "markdown",
"id": "d6aac680-ba32-4c19-8864-6471cf0e7d5a",
"metadata": {},
"source": [
"### Couchbase Cache\n",
"\n",
"The standard cache that looks for an exact match of the user prompt."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9b4764e4-c75f-4185-b326-524287a826be",
"metadata": {},
"outputs": [],
"source": [
"# Create couchbase connection object\n",
"from datetime import timedelta\n",
"\n",
"from couchbase.auth import PasswordAuthenticator\n",
"from couchbase.cluster import Cluster\n",
"from couchbase.options import ClusterOptions\n",
"from langchain_couchbase.cache import CouchbaseCache\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"COUCHBASE_CONNECTION_STRING = (\n",
"    \"couchbase://localhost\"  # or \"couchbases://localhost\" if using TLS\n",
")\n",
"DB_USERNAME = \"Administrator\"\n",
"DB_PASSWORD = \"Password\"\n",
"\n",
"auth = PasswordAuthenticator(DB_USERNAME, DB_PASSWORD)\n",
"options = ClusterOptions(auth)\n",
"cluster = Cluster(COUCHBASE_CONNECTION_STRING, options)\n",
"\n",
"# Wait until the cluster is ready for use.\n",
"cluster.wait_until_ready(timedelta(seconds=5))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4b5e73c5-92c1-4eab-84e2-77924ea9c123",
"metadata": {},
"outputs": [],
"source": [
"# Specify the bucket, scope and collection to store the cached documents\n",
"BUCKET_NAME = \"langchain-testing\"\n",
"SCOPE_NAME = \"_default\"\n",
"COLLECTION_NAME = \"_default\"\n",
"\n",
"set_llm_cache(\n",
"    CouchbaseCache(\n",
"        cluster=cluster,\n",
"        bucket_name=BUCKET_NAME,\n",
"        scope_name=SCOPE_NAME,\n",
"        collection_name=COLLECTION_NAME,\n",
"    )\n",
")"
]
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "db8d28cc-8d93-47b4-8326-57a29a06fb3c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CPU times: user 22.2 ms, sys: 14 ms, total: 36.2 ms\n",
|
||||
"Wall time: 938 ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was two-tired!\""
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"# The first time, it is not yet in the cache, so it should take longer\n",
|
||||
"llm.invoke(\"Tell me a joke\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "b470dc81-2e7f-4743-9435-ce9071394eea",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"CPU times: user 53 ms, sys: 29 ms, total: 82 ms\n",
|
||||
"Wall time: 84.2 ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was two-tired!\""
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"# The second time, it is in the cache, so it should be much faster\n",
|
||||
"llm.invoke(\"Tell me a joke\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43626f33-d184-4260-b641-c9341cef5842",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Couchbase Semantic Cache\n",
|
||||
"Semantic caching allows users to retrieve cached prompts based on semantic similarity between the user input and previously cached inputs. Under the hood it uses Couchbase as both a cache and a vectorstore. This needs an appropriate Vector Search Index defined to work. Please look at the usage example on how to set up the index."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "6b470c03-d7fe-4270-89e1-638251619a53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create Couchbase connection object\n",
|
||||
"from datetime import timedelta\n",
|
||||
"\n",
|
||||
"from couchbase.auth import PasswordAuthenticator\n",
|
||||
"from couchbase.cluster import Cluster\n",
|
||||
"from couchbase.options import ClusterOptions\n",
|
||||
"from langchain_couchbase.cache import CouchbaseSemanticCache\n",
|
||||
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
|
||||
"\n",
|
||||
"COUCHBASE_CONNECTION_STRING = (\n",
|
||||
" \"couchbase://localhost\" # or \"couchbases://localhost\" if using TLS\n",
|
||||
")\n",
|
||||
"DB_USERNAME = \"Administrator\"\n",
|
||||
"DB_PASSWORD = \"Password\"\n",
|
||||
"\n",
|
||||
"auth = PasswordAuthenticator(DB_USERNAME, DB_PASSWORD)\n",
|
||||
"options = ClusterOptions(auth)\n",
|
||||
"cluster = Cluster(COUCHBASE_CONNECTION_STRING, options)\n",
|
||||
"\n",
|
||||
"# Wait until the cluster is ready for use.\n",
|
||||
"cluster.wait_until_ready(timedelta(seconds=5))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f831bc4c-f330-4bd7-9b80-76771d91827e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Notes:\n",
|
||||
"- The search index for the semantic cache needs to be defined before using the semantic cache. \n",
|
||||
"- The optional parameter, `score_threshold` in the Semantic Cache that you can use to tune the results of the semantic search.\n",
|
||||
"\n",
|
||||
"### How to Import an Index to the Full Text Search service?\n",
|
||||
" - [Couchbase Server](https://docs.couchbase.com/server/current/search/import-search-index.html)\n",
|
||||
" - Click on Search -> Add Index -> Import\n",
|
||||
" - Copy the following Index definition in the Import screen\n",
|
||||
" - Click on Create Index to create the index.\n",
|
||||
" - [Couchbase Capella](https://docs.couchbase.com/cloud/search/import-search-index.html)\n",
|
||||
" - Copy the index definition to a new file `index.json`\n",
|
||||
" - Import the file in Capella using the instructions in the documentation.\n",
|
||||
" - Click on Create Index to create the index.\n",
|
||||
"\n",
|
||||
"#### Example index for the vector search. \n",
|
||||
" ```\n",
|
||||
" {\n",
|
||||
" \"type\": \"fulltext-index\",\n",
|
||||
" \"name\": \"langchain-testing._default.semantic-cache-index\",\n",
|
||||
" \"sourceType\": \"gocbcore\",\n",
|
||||
" \"sourceName\": \"langchain-testing\",\n",
|
||||
" \"planParams\": {\n",
|
||||
" \"maxPartitionsPerPIndex\": 1024,\n",
|
||||
" \"indexPartitions\": 16\n",
|
||||
" },\n",
|
||||
" \"params\": {\n",
|
||||
" \"doc_config\": {\n",
|
||||
" \"docid_prefix_delim\": \"\",\n",
|
||||
" \"docid_regexp\": \"\",\n",
|
||||
" \"mode\": \"scope.collection.type_field\",\n",
|
||||
" \"type_field\": \"type\"\n",
|
||||
" },\n",
|
||||
" \"mapping\": {\n",
|
||||
" \"analysis\": {},\n",
|
||||
" \"default_analyzer\": \"standard\",\n",
|
||||
" \"default_datetime_parser\": \"dateTimeOptional\",\n",
|
||||
" \"default_field\": \"_all\",\n",
|
||||
" \"default_mapping\": {\n",
|
||||
" \"dynamic\": true,\n",
|
||||
" \"enabled\": false\n",
|
||||
" },\n",
|
||||
" \"default_type\": \"_default\",\n",
|
||||
" \"docvalues_dynamic\": false,\n",
|
||||
" \"index_dynamic\": true,\n",
|
||||
" \"store_dynamic\": true,\n",
|
||||
" \"type_field\": \"_type\",\n",
|
||||
" \"types\": {\n",
|
||||
" \"_default.semantic-cache\": {\n",
|
||||
" \"dynamic\": false,\n",
|
||||
" \"enabled\": true,\n",
|
||||
" \"properties\": {\n",
|
||||
" \"embedding\": {\n",
|
||||
" \"dynamic\": false,\n",
|
||||
" \"enabled\": true,\n",
|
||||
" \"fields\": [\n",
|
||||
" {\n",
|
||||
" \"dims\": 1536,\n",
|
||||
" \"index\": true,\n",
|
||||
" \"name\": \"embedding\",\n",
|
||||
" \"similarity\": \"dot_product\",\n",
|
||||
" \"type\": \"vector\",\n",
|
||||
" \"vector_index_optimized_for\": \"recall\"\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"metadata\": {\n",
|
||||
" \"dynamic\": true,\n",
|
||||
" \"enabled\": true\n",
|
||||
" },\n",
|
||||
" \"text\": {\n",
|
||||
" \"dynamic\": false,\n",
|
||||
" \"enabled\": true,\n",
|
||||
" \"fields\": [\n",
|
||||
" {\n",
|
||||
" \"index\": true,\n",
|
||||
" \"name\": \"text\",\n",
|
||||
" \"store\": true,\n",
|
||||
" \"type\": \"text\"\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"store\": {\n",
|
||||
" \"indexType\": \"scorch\",\n",
|
||||
" \"segmentVersion\": 16\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"sourceParams\": {}\n",
|
||||
" }\n",
|
||||
" ```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "ae0766c8-ea34-4604-b0dc-cf2bbe8077f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"BUCKET_NAME = \"langchain-testing\"\n",
|
||||
"SCOPE_NAME = \"_default\"\n",
|
||||
"COLLECTION_NAME = \"semantic-cache\"\n",
|
||||
"INDEX_NAME = \"semantic-cache-index\"\n",
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"\n",
|
||||
"cache = CouchbaseSemanticCache(\n",
|
||||
" cluster=cluster,\n",
|
||||
" embedding=embeddings,\n",
|
||||
" bucket_name=BUCKET_NAME,\n",
|
||||
" scope_name=SCOPE_NAME,\n",
|
||||
" collection_name=COLLECTION_NAME,\n",
|
||||
" index_name=INDEX_NAME,\n",
|
||||
" score_threshold=0.8,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"set_llm_cache(cache)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "a2e82743-10ea-4319-b43e-193475ae5449",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"The average lifespan of a dog is around 12 years, but this can vary depending on the breed, size, and overall health of the individual dog. Some smaller breeds may live longer, while larger breeds may have shorter lifespans. Proper care, diet, and exercise can also play a role in extending a dog's lifespan.\n",
|
||||
"CPU times: user 826 ms, sys: 2.46 s, total: 3.28 s\n",
|
||||
"Wall time: 2.87 s\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"# The first time, it is not yet in the cache, so it should take longer\n",
|
||||
"print(llm.invoke(\"How long do dogs live?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "c36f4e29-d872-4334-a1f1-0e6d10c5d9f2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"The average lifespan of a dog is around 12 years, but this can vary depending on the breed, size, and overall health of the individual dog. Some smaller breeds may live longer, while larger breeds may have shorter lifespans. Proper care, diet, and exercise can also play a role in extending a dog's lifespan.\n",
|
||||
"CPU times: user 9.82 ms, sys: 2.61 ms, total: 12.4 ms\n",
|
||||
"Wall time: 311 ms\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"# The second time, it is in the cache, so it should be much faster\n",
|
||||
"print(llm.invoke(\"What is the expected lifespan of a dog?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ae1f5e1c-085e-4998-9f2d-b5867d2c3d5b",
|
||||
@@ -2228,7 +2574,9 @@
|
||||
"| langchain_core.caches | [InMemoryCache](https://api.python.langchain.com/en/latest/caches/langchain_core.caches.InMemoryCache.html) |\n",
|
||||
"| langchain_elasticsearch.cache | [ElasticsearchCache](https://api.python.langchain.com/en/latest/cache/langchain_elasticsearch.cache.ElasticsearchCache.html) |\n",
|
||||
"| langchain_mongodb.cache | [MongoDBAtlasSemanticCache](https://api.python.langchain.com/en/latest/cache/langchain_mongodb.cache.MongoDBAtlasSemanticCache.html) |\n",
|
||||
"| langchain_mongodb.cache | [MongoDBCache](https://api.python.langchain.com/en/latest/cache/langchain_mongodb.cache.MongoDBCache.html) |\n"
|
||||
"| langchain_mongodb.cache | [MongoDBCache](https://api.python.langchain.com/en/latest/cache/langchain_mongodb.cache.MongoDBCache.html) |\n",
|
||||
"| langchain_couchbase.cache | [CouchbaseCache](https://api.python.langchain.com/en/latest/cache/langchain_couchbase.cache.CouchbaseCache.html) |\n",
|
||||
"| langchain_couchbase.cache | [CouchbaseSemanticCache](https://api.python.langchain.com/en/latest/cache/langchain_couchbase.cache.CouchbaseSemanticCache.html) |\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -2256,7 +2604,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install the package\n",
|
||||
"%pip install --upgrade --quiet dashscope"
|
||||
"%pip install --upgrade --quiet langchain-community dashscope"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -27,3 +27,65 @@ See a [usage example](/docs/integrations/document_loaders/couchbase).
|
||||
```python
|
||||
from langchain_community.document_loaders.couchbase import CouchbaseLoader
|
||||
```
|
||||
|
||||
## LLM Caches
|
||||
|
||||
### CouchbaseCache
|
||||
Use Couchbase as a cache for prompts and responses.
|
||||
|
||||
See a [usage example](/docs/integrations/llm_caching/#couchbase-cache).
|
||||
|
||||
To import this cache:
|
||||
```python
|
||||
from langchain_couchbase.cache import CouchbaseCache
|
||||
```
|
||||
|
||||
To use this cache with your LLMs:
|
||||
```python
|
||||
from langchain_core.globals import set_llm_cache
|
||||
|
||||
cluster = couchbase_cluster_connection_object
|
||||
|
||||
set_llm_cache(
|
||||
CouchbaseCache(
|
||||
cluster=cluster,
|
||||
bucket_name=BUCKET_NAME,
|
||||
scope_name=SCOPE_NAME,
|
||||
collection_name=COLLECTION_NAME,
|
||||
)
|
||||
)
|
||||
```
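
Once the cache is set globally, any call with a byte-identical prompt is served from Couchbase on subsequent invocations. A minimal sketch of the effect (the model choice here is an assumption; any LangChain LLM works once `set_llm_cache` has run):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")

llm.invoke("Tell me a joke")  # first call: hits the model and writes the answer to the cache
llm.invoke("Tell me a joke")  # identical prompt: answered from Couchbase, no model call
```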


### CouchbaseSemanticCache
Semantic caching allows users to retrieve cached prompts based on the semantic similarity between the user input and previously cached inputs. Under the hood it uses Couchbase as both a cache and a vectorstore.
The CouchbaseSemanticCache needs a Search Index defined to work. Please look at the [usage example](/docs/integrations/vectorstores/couchbase) on how to set up the index.

See a [usage example](/docs/integrations/llm_caching/#couchbase-semantic-cache).

To import this cache:
```python
from langchain_couchbase.cache import CouchbaseSemanticCache
```

To use this cache with your LLMs:
```python
from langchain_core.globals import set_llm_cache

# use any embedding provider...
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
cluster = couchbase_cluster_connection_object

set_llm_cache(
    CouchbaseSemanticCache(
        cluster=cluster,
        embedding=embeddings,
        bucket_name=BUCKET_NAME,
        scope_name=SCOPE_NAME,
        collection_name=COLLECTION_NAME,
        index_name=INDEX_NAME,
    )
)
```
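
Because the semantic cache matches on embedding similarity rather than exact text, a rephrased prompt can also be answered from the cache, subject to `score_threshold`. A sketch, assuming `llm` is a model configured after the `set_llm_cache` call above, reusing the prompts from the notebook:

```python
llm.invoke("How long do dogs live?")  # cache miss: calls the model and stores the answer
# A semantically similar rephrasing can now be served from the cache:
llm.invoke("What is the expected lifespan of a dog?")
```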

@@ -61,7 +61,7 @@ When ready to deploy, you can self-host models with NVIDIA NIM—which is includ
```python
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank

# connect to an chat NIM running at localhost:8000, specifyig a specific model
# connect to a chat NIM running at localhost:8000, specifying a model
llm = ChatNVIDIA(base_url="http://localhost:8000/v1", model="meta/llama3-8b-instruct")

# connect to an embedding NIM running at localhost:8080

@@ -202,7 +202,7 @@ Prem Templates are also available for Streaming too.

## Prem Embeddings

In this section we are going to dicuss how we can get access to different embedding model using `PremEmbeddings` with LangChain. Lets start by importing our modules and setting our API Key.
In this section we cover how we can get access to different embedding models using `PremEmbeddings` with LangChain. Let's start by importing our modules and setting our API Key.

```python
import os

@@ -309,9 +309,9 @@
"documents = TextLoader(\"../../how_to/state_of_the_union.txt\").load()\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
"texts = text_splitter.split_documents(documents)\n",
"retriever = FAISS.from_documents(texts, CohereEmbeddings()).as_retriever(\n",
"    search_kwargs={\"k\": 20}\n",
")\n",
"retriever = FAISS.from_documents(\n",
"    texts, CohereEmbeddings(model=\"embed-english-v3.0\")\n",
").as_retriever(search_kwargs={\"k\": 20})\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = retriever.invoke(query)\n",
@@ -324,7 +324,8 @@
"metadata": {},
"source": [
"## Doing reranking with CohereRerank\n",
"Now let's wrap our base retriever with a `ContextualCompressionRetriever`. We'll add an `CohereRerank`, uses the Cohere rerank endpoint to rerank the returned results."
"Now let's wrap our base retriever with a `ContextualCompressionRetriever`. We'll add an `CohereRerank`, uses the Cohere rerank endpoint to rerank the returned results.\n",
|
||||
"Do note that it is mandatory to specify the model name in CohereRerank!"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -339,7 +340,7 @@
|
||||
"from langchain_community.llms import Cohere\n",
|
||||
"\n",
|
||||
"llm = Cohere(temperature=0)\n",
|
||||
"compressor = CohereRerank()\n",
|
||||
"compressor = CohereRerank(model=\"rerank-english-v3.0\")\n",
|
||||
"compression_retriever = ContextualCompressionRetriever(\n",
|
||||
" base_compressor=compressor, base_retriever=retriever\n",
|
||||
")\n",
|
||||
|
||||
@@ -40,7 +40,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = CohereEmbeddings(model=\"embed-english-light-v3.0\")"
|
||||
"embeddings = CohereEmbeddings(\n",
|
||||
" model=\"embed-english-light-v3.0\"\n",
|
||||
") # It is mandatory to pass a model parameter to initialize the CohereEmbeddings object"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -169,6 +169,23 @@
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Specify additional properties for the Azure client such as the following https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations\n",
|
||||
"vector_store: AzureSearch = AzureSearch(\n",
|
||||
" azure_search_endpoint=vector_store_address,\n",
|
||||
" azure_search_key=vector_store_password,\n",
|
||||
" index_name=index_name,\n",
|
||||
" embedding_function=embeddings.embed_query,\n",
|
||||
" # Configure max retries for the Azure client\n",
|
||||
" additional_search_client_options={\"retry_total\": 4},\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
||||
@@ -78,7 +78,7 @@
|
||||
"# See docker command above to launch a postgres instance with pgvector enabled.\n",
|
||||
"connection = \"postgresql+psycopg://langchain:langchain@localhost:6024/langchain\" # Uses psycopg3!\n",
|
||||
"collection_name = \"my_docs\"\n",
|
||||
"embeddings = CohereEmbeddings()\n",
|
||||
"embeddings = CohereEmbeddings(model=\"embed-english-v3.0\")\n",
|
||||
"\n",
|
||||
"vectorstore = PGVector(\n",
|
||||
" embeddings=embeddings,\n",
|
||||
|
||||
@@ -107,7 +107,7 @@
|
||||
"```\n",
|
||||
"## Preview\n",
|
||||
"\n",
|
||||
"In this guide we’ll build a QA app over as website. The specific website we will use is the [LLM Powered Autonomous\n",
|
||||
"In this guide we’ll build an app that answers questions about the content of a website. The specific website we will use is the [LLM Powered Autonomous\n",
|
||||
"Agents](https://lilianweng.github.io/posts/2023-06-23-agent/) blog post\n",
|
||||
"by Lilian Weng, which allows us to ask questions about the contents of\n",
|
||||
"the post.\n",
|
||||
|
||||
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
|
||||
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
|
||||
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
|
||||
|
||||
class ChatAnyscale(ChatOpenAI):
|
||||
|
||||
@@ -141,9 +141,8 @@ class CustomOpenAIChatContentFormatter(ContentFormatterBase):
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
||||
return ChatGeneration(
|
||||
message=BaseMessage(
|
||||
message=AIMessage(
|
||||
content=choice.strip(),
|
||||
type="assistant",
|
||||
),
|
||||
generation_info=None,
|
||||
)
|
||||
@@ -158,7 +157,9 @@ class CustomOpenAIChatContentFormatter(ContentFormatterBase):
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
||||
return ChatGeneration(
|
||||
message=BaseMessage(
|
||||
message=AIMessage(content=choice["message"]["content"].strip())
|
||||
if choice["message"]["role"] == "assistant"
|
||||
else BaseMessage(
|
||||
content=choice["message"]["content"].strip(),
|
||||
type=choice["message"]["role"],
|
||||
),
|
||||
|
||||
@@ -48,6 +48,7 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
@@ -96,7 +97,7 @@ def _parse_tool_calling(tool_call: dict) -> ToolCall:
|
||||
name = tool_call["function"].get("name", "")
|
||||
args = json.loads(tool_call["function"]["arguments"])
|
||||
id = tool_call.get("id")
|
||||
return ToolCall(name=name, args=args, id=id)
|
||||
return create_tool_call(name=name, args=args, id=id)
|
||||
|
||||
|
||||
def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
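
This hunk swaps direct `ToolCall(...)` construction for the `tool_call` factory (imported as `create_tool_call`). The factory, added to `langchain_core.messages.tool` later in this changeset, stamps the `type` discriminator onto the dict, which the bare TypedDict constructor leaves unset. A minimal sketch of the difference:

```python
from langchain_core.messages.tool import ToolCall, tool_call as create_tool_call

plain = ToolCall(name="search", args={"q": "langchain"}, id="call_1")
# plain has no "type" key: {'name': 'search', 'args': {'q': 'langchain'}, 'id': 'call_1'}

stamped = create_tool_call(name="search", args={"q": "langchain"}, id="call_1")
# stamped carries the discriminator: {..., 'type': 'tool_call'}
```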

@@ -36,9 +36,11 @@ from langchain_core.messages import (
InvalidToolCall,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
@@ -63,7 +65,7 @@ def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationCh
message = generated_result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
ToolCallChunk(
create_tool_call_chunk(
name=tool_call["name"],
args=json.dumps(tool_call["args"]),
id=tool_call["id"],
@@ -189,7 +191,7 @@ def _extract_tool_calls_from_edenai_response(
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(
ToolCall(
create_tool_call(
name=raw_tool_call["name"],
args=json.loads(raw_tool_call["arguments"]),
id=raw_tool_call["id"],
@@ -197,7 +199,7 @@ def _extract_tool_calls_from_edenai_response(
)
except json.JSONDecodeError as exc:
invalid_tool_calls.append(
InvalidToolCall(
create_invalid_tool_call(
name=raw_tool_call.get("name"),
args=raw_tool_call.get("arguments"),
id=raw_tool_call.get("id"),

@@ -144,7 +144,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
elif (
isinstance(temp_image_url, dict) and "url" in temp_image_url
):
image_url = temp_image_url
image_url = temp_image_url["url"]
else:
raise ValueError(
"Only string image_url or dict with string 'url' "

@@ -60,7 +60,7 @@ class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
List of scores, one for each pair.
"""
scores = self.client.predict(text_pairs)
# Somes models e.g bert-multilingual-passage-reranking-msmarco
# Some models e.g bert-multilingual-passage-reranking-msmarco
# gives two score not_relevant and relevant as compare with the query.
if len(scores.shape) > 1:  # we are going to get the relevant scores
scores = map(lambda x: x[1], scores)

@@ -60,7 +60,7 @@ class AscendEmbeddings(Embeddings, BaseModel):
raise ValueError("model_path is required")
if not os.access(values["model_path"], os.F_OK):
raise FileNotFoundError(
f"Unabled to find valid model path in [{values['model_path']}]"
f"Unable to find valid model path in [{values['model_path']}]"
)
try:
import torch_npu

@@ -555,10 +555,11 @@ class Neo4jGraph(GraphStore):
el["labelsOrTypes"] == [BASE_ENTITY_LABEL]
and el["properties"] == ["id"]
for el in self.structured_schema.get("metadata", {}).get(
"constraint"
"constraint", []
)
]
)

if not constraint_exists:
# Create constraint
self.query(
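
The added default in `.get("constraint", [])` is the point of this hunk: when the metadata dict lacks the key, iterating over the old `.get("constraint")` raises a `TypeError` on `None`, while the new form falls through to an empty iteration. A two-line illustration in plain Python:

```python
metadata = {}  # no "constraint" key present
constraints = [el for el in metadata.get("constraint", [])]  # safe: yields []
# [el for el in metadata.get("constraint")] would raise TypeError: 'NoneType' is not iterable
```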

@@ -1,5 +1,6 @@
import json
import logging
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional

from langchain_core._api.deprecation import deprecated
@@ -11,7 +12,6 @@ from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
@@ -177,16 +177,17 @@ class HuggingFaceEndpoint(LLM):
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
try:
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
login(token=huggingfacehub_api_token)
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv(
"HUGGINGFACEHUB_API_TOKEN"
)
if huggingfacehub_api_token is not None:
try:
login(token=huggingfacehub_api_token)
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e

from huggingface_hub import AsyncInferenceClient, InferenceClient


@@ -72,7 +72,7 @@ class SQLStore(BaseStore[str, bytes]):
from langchain_rag.storage import SQLStore

# Instantiate the SQLStore with the root path
sql_store = SQLStore(namespace="test", db_url="sqllite://:memory:")
sql_store = SQLStore(namespace="test", db_url="sqlite://:memory:")

# Set values for keys
sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

@@ -80,7 +80,7 @@ class SemanticScholarAPIWrapper(BaseModel):
f"Published year: {getattr(item, 'year', None)}\n"
f"Title: {getattr(item, 'title', None)}\n"
f"Authors: {authors}\n"
f"Astract: {getattr(item, 'abstract', None)}\n"
f"Abstract: {getattr(item, 'abstract', None)}\n"
)

if documents:

@@ -86,6 +86,7 @@ def _get_search_client(
user_agent: Optional[str] = "langchain",
cors_options: Optional[CorsOptions] = None,
async_: bool = False,
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError
@@ -109,6 +110,7 @@ def _get_search_client(
VectorSearchProfile,
)

additional_search_client_options = additional_search_client_options or {}
default_fields = default_fields or []
if key is None:
credential = DefaultAzureCredential()
@@ -225,6 +227,7 @@ def _get_search_client(
index_name=index_name,
credential=credential,
user_agent=user_agent,
**additional_search_client_options,
)
else:
return AsyncSearchClient(
@@ -232,6 +235,7 @@ def _get_search_client(
index_name=index_name,
credential=credential,
user_agent=user_agent,
**additional_search_client_options,
)


@@ -256,6 +260,7 @@ class AzureSearch(VectorStore):
cors_options: Optional[CorsOptions] = None,
*,
vector_search_dimensions: Optional[int] = None,
additional_search_client_options: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
try:
@@ -320,6 +325,22 @@ class AzureSearch(VectorStore):
default_fields=default_fields,
user_agent=user_agent,
cors_options=cors_options,
additional_search_client_options=additional_search_client_options,
)
self.async_client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
semantic_configurations=semantic_configurations,
scoring_profiles=scoring_profiles,
default_scoring_profile=default_scoring_profile,
default_fields=default_fields,
user_agent=user_agent,
cors_options=cors_options,
async_=True,
)
self.search_type = search_type
self.semantic_configuration_name = semantic_configuration_name
@@ -338,23 +359,6 @@ class AzureSearch(VectorStore):
self._user_agent = user_agent
self._cors_options = cors_options

def _async_client(self) -> AsyncSearchClient:
return _get_search_client(
self._azure_search_endpoint,
self._azure_search_key,
self._index_name,
semantic_configuration_name=self._semantic_configuration_name,
fields=self._fields,
vector_search=self._vector_search,
semantic_configurations=self._semantic_configurations,
scoring_profiles=self._scoring_profiles,
default_scoring_profile=self._default_scoring_profile,
default_fields=self._default_fields,
user_agent=self._user_agent,
cors_options=self._cors_options,
async_=True,
)

@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Support embedding object directly
@@ -513,7 +517,7 @@ class AzureSearch(VectorStore):
ids.append(key)
# Upload data in batches
if len(data) == MAX_UPLOAD_BATCH_SIZE:
async with self._async_client() as async_client:
async with self.async_client as async_client:
response = await async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if not all(r.succeeded for r in response):
@@ -526,7 +530,7 @@ class AzureSearch(VectorStore):
return ids

# Upload data to index
async with self._async_client() as async_client:
async with self.async_client as async_client:
response = await async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if all(r.succeeded for r in response):
@@ -561,7 +565,7 @@ class AzureSearch(VectorStore):
False otherwise.
"""
if ids:
async with self._async_client() as async_client:
async with self.async_client as async_client:
res = await async_client.delete_documents([{"id": i} for i in ids])
return len(res) > 0
else:
@@ -739,11 +743,11 @@ class AzureSearch(VectorStore):
to the query and score for each
"""
embedding = await self._aembed_query(query)
docs, scores, _ = await self._asimple_search(
results = await self._asimple_search(
embedding, "", k, filters=filters, **kwargs
)

return list(zip(docs, scores))
return _results_to_documents(results)

def max_marginal_relevance_search_with_score(
self,
@@ -807,14 +811,12 @@ class AzureSearch(VectorStore):
to the query and score for each
"""
embedding = await self._aembed_query(query)
docs, scores, vectors = await self._asimple_search(
results = await self._asimple_search(
embedding, "", fetch_k, filters=filters, **kwargs
)

return await self._areorder_results_with_maximal_marginal_relevance(
docs,
scores,
vectors,
return await _areorder_results_with_maximal_marginal_relevance(
results,
query_embedding=np.array(embedding),
lambda_mult=lambda_mult,
k=k,
@@ -890,11 +892,11 @@ class AzureSearch(VectorStore):
"""

embedding = await self._aembed_query(query)
docs, scores, _ = await self._asimple_search(
results = await self._asimple_search(
embedding, query, k, filters=filters, **kwargs
)

return list(zip(docs, scores))
return _results_to_documents(results)

def hybrid_search_with_relevance_scores(
self,
@@ -992,14 +994,12 @@ class AzureSearch(VectorStore):
"""

embedding = await self._aembed_query(query)
docs, scores, vectors = await self._asimple_search(
results = await self._asimple_search(
embedding, query, fetch_k, filters=filters, **kwargs
)

return await self._areorder_results_with_maximal_marginal_relevance(
docs,
scores,
vectors,
return await _areorder_results_with_maximal_marginal_relevance(
results,
query_embedding=np.array(embedding),
lambda_mult=lambda_mult,
k=k,
@@ -1049,7 +1049,7 @@ class AzureSearch(VectorStore):
*,
filters: Optional[str] = None,
**kwargs: Any,
) -> Tuple[List[Document], List[float], List[List[float]]]:
) -> SearchItemPaged[dict]:
"""Perform vector or hybrid search in the Azure search index.

Args:
@@ -1063,8 +1063,8 @@ class AzureSearch(VectorStore):
"""
from azure.search.documents.models import VectorizedQuery

async with self._async_client() as async_client:
results = await async_client.search(
async with self.async_client as async_client:
return await async_client.search(
search_text=text_query,
vector_queries=[
VectorizedQuery(
@@ -1077,18 +1077,6 @@ class AzureSearch(VectorStore):
top=k,
**kwargs,
)
docs = [
(
_result_to_document(result),
float(result["@search.score"]),
result[FIELDS_CONTENT_VECTOR],
)
async for result in results
]
if not docs:
raise ValueError(f"No {docs=}")
documents, scores, vectors = map(list, zip(*docs))
return documents, scores, vectors

def semantic_hybrid_search(
self, query: str, k: int = 4, **kwargs: Any
@@ -1300,7 +1288,7 @@ class AzureSearch(VectorStore):
from azure.search.documents.models import VectorizedQuery

vector = await self._aembed_query(query)
async with self._async_client() as async_client:
async with self.async_client as async_client:
results = await async_client.search(
search_text=query,
vector_queries=[
@@ -1475,30 +1463,6 @@ class AzureSearch(VectorStore):
azure_search.add_embeddings(text_embeddings, metadatas, **kwargs)
return azure_search

async def _areorder_results_with_maximal_marginal_relevance(
self,
documents: List[Document],
scores: List[float],
vectors: List[List[float]],
query_embedding: np.ndarray,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[Tuple[Document, float]]:
# Get the new order of results.
new_ordering = maximal_marginal_relevance(
query_embedding, vectors, k=k, lambda_mult=lambda_mult
)

# Reorder the values and return.
ret: List[Tuple[Document, float]] = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
ret.append((documents[x], scores[x]))  # type: ignore

return ret

def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever:  # type: ignore
"""Return AzureSearchVectorStoreRetriever initialized from this VectorStore.

@@ -1666,6 +1630,39 @@ def _results_to_documents(
return docs


async def _areorder_results_with_maximal_marginal_relevance(
results: SearchItemPaged[Dict],
query_embedding: np.ndarray,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[Tuple[Document, float]]:
# Convert results to Document objects
docs = [
(
_result_to_document(result),
float(result["@search.score"]),
result[FIELDS_CONTENT_VECTOR],
)
for result in results
]
documents, scores, vectors = map(list, zip(*docs))

# Get the new order of results.
new_ordering = maximal_marginal_relevance(
query_embedding, vectors, k=k, lambda_mult=lambda_mult
)

# Reorder the values and return.
ret: List[Tuple[Document, float]] = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
ret.append((documents[x], scores[x]))  # type: ignore

return ret


def _reorder_results_with_maximal_marginal_relevance(
results: SearchItemPaged[Dict],
query_embedding: np.ndarray,

@@ -9,7 +9,7 @@ from langchain_community.tools.zenguard.tool import Detector, ZenGuardTool
@pytest.fixture()
def zenguard_tool() -> ZenGuardTool:
if os.getenv("ZENGUARD_API_KEY") is None:
raise ValueError("ZENGUARD_API_KEY is not set in environment varibale")
raise ValueError("ZENGUARD_API_KEY is not set in environment variable")
return ZenGuardTool()


@@ -12,7 +12,7 @@ PAGE_1 = """
Hello.
<a href="relative">Relative</a>
<a href="/relative-base">Relative base.</a>
<a href="http://cnn.com">Aboslute</a>
<a href="http://cnn.com">Absolute</a>
<a href="//same.foo">Test</a>
</body>
</html>

@@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
"""Test that tools defined in this repo accept a run manager argument."""
# This wouldn't be necessary if the BaseTool had a strict API.
if cls._run is not BaseTool._arun:
if cls._run is not BaseTool._run:
run_func = cls._run
params = inspect.signature(run_func).parameters
assert "run_manager" in params

@@ -1,5 +1,5 @@
import json
from typing import List, Optional
from typing import Any, Dict, List, Optional
from unittest.mock import patch

import pytest
@@ -121,12 +121,15 @@ def mock_default_index(*args, **kwargs):  # type: ignore[no-untyped-def]
)


def create_vector_store() -> AzureSearch:
def create_vector_store(
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> AzureSearch:
return AzureSearch(
azure_search_endpoint=DEFAULT_ENDPOINT,
azure_search_key=DEFAULT_KEY,
index_name=DEFAULT_INDEX_NAME,
embedding_function=DEFAULT_EMBEDDING_MODEL,
additional_search_client_options=additional_search_client_options,
)


@@ -168,3 +171,20 @@ def test_init_new_index() -> None:
assert json.dumps(created_index.as_dict()) == json.dumps(
mock_default_index().as_dict()
)


@pytest.mark.requires("azure.search.documents")
def test_additional_search_options() -> None:
from azure.search.documents.indexes import SearchIndexClient

def mock_create_index() -> None:
pytest.fail("Should not create index in this test")

with patch.multiple(
SearchIndexClient, get_index=mock_default_index, create_index=mock_create_index
):
vector_store = create_vector_store(
additional_search_client_options={"api_version": "test"}
)
assert vector_store.client is not None
assert vector_store.client._api_version == "test"

@@ -15,7 +15,7 @@ PathLike = Union[str, PurePath]
class BaseMedia(Serializable):
"""Use to represent media content.

Media objets can be used to represent raw data, such as text or binary data.
Media objects can be used to represent raw data, such as text or binary data.

LangChain Media objects allow associating metadata and an optional identifier
with the content.

@@ -12,6 +12,7 @@ from typing import (
Optional,
)

from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
@@ -130,11 +131,14 @@ def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
)


@beta(message="Added in version 0.2.14 of langchain_core. API subject to change.")
class GraphVectorStore(VectorStore):
"""A hybrid vector-and-graph graph store.

Document chunks support vector-similarity search as well as edges linking
chunks based on structural and semantic properties.

.. versionadded:: 0.2.14
"""

@abstractmethod

@@ -15,11 +15,18 @@ from langchain_core.messages.tool import (
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.messages.tool import (
invalid_tool_call as create_invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import (
parse_partial_json,
)
from langchain_core.utils.json import parse_partial_json


class UsageMetadata(TypedDict):
@@ -106,24 +113,55 @@ class AIMessage(BaseMessage):

@root_validator(pre=True)
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
values.get("tool_calls")
or values.get("invalid_tool_calls")
or values.get("tool_call_chunks")
check_additional_kwargs = not any(
values.get(k)
for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
)
if raw_tool_calls and not tool_calls:
if check_additional_kwargs and (
raw_tool_calls := values.get("additional_kwargs", {}).get("tool_calls")
):
try:
if issubclass(cls, AIMessageChunk):  # type: ignore
values["tool_call_chunks"] = default_tool_chunk_parser(
raw_tool_calls
)
else:
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
parsed_tool_calls, parsed_invalid_tool_calls = default_tool_parser(
raw_tool_calls
)
values["tool_calls"] = parsed_tool_calls
values["invalid_tool_calls"] = parsed_invalid_tool_calls
except Exception:
pass

# Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"):
updated: List = []
for tc in tool_calls:
updated.append(
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
)
values["tool_calls"] = updated
if invalid_tool_calls := values.get("invalid_tool_calls"):
updated = []
for tc in invalid_tool_calls:
updated.append(
create_invalid_tool_call(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["invalid_tool_calls"] = updated

if tool_call_chunks := values.get("tool_call_chunks"):
updated = []
for tc in tool_call_chunks:
updated.append(
create_tool_call_chunk(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["tool_call_chunks"] = updated

return values

def pretty_repr(self, html: bool = False) -> str:
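
The net effect of the reworked validator: tool-call dicts supplied by callers are normalized so that each carries its `type` key. A short sketch against the public `AIMessage` API:

```python
from langchain_core.messages import AIMessage

msg = AIMessage(
    content="",
    tool_calls=[{"name": "search", "args": {"q": "x"}, "id": "call_1"}],
)
assert msg.tool_calls[0]["type"] == "tool_call"  # stamped by the validator above
```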
@@ -216,7 +254,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
if not values["tool_call_chunks"]:
if values["tool_calls"]:
values["tool_call_chunks"] = [
ToolCallChunk(
create_tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
@@ -228,7 +266,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
tool_call_chunks = values.get("tool_call_chunks", [])
tool_call_chunks.extend(
[
ToolCallChunk(
create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None
)
for tc in values["invalid_tool_calls"]
@@ -244,7 +282,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
if isinstance(args_, dict):
tool_calls.append(
ToolCall(
create_tool_call(
name=chunk["name"] or "",
args=args_,
id=chunk["id"],
@@ -254,7 +292,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
raise ValueError("Malformed args.")
except Exception:
invalid_tool_calls.append(
InvalidToolCall(
create_invalid_tool_call(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
@@ -297,7 +335,7 @@ def add_ai_message_chunks(
left.tool_call_chunks, *(o.tool_call_chunks for o in others)
):
tool_call_chunks = [
ToolCallChunk(
create_tool_call_chunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),

@@ -1,7 +1,7 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict

from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.utils._merge import merge_dicts, merge_obj
@@ -21,8 +21,11 @@ class ToolMessage(BaseMessage):

ToolMessage(content='42', tool_call_id='call_Jja7J89XsjrOLA5r!MEOW!SL')


Example: A ToolMessage where only part of the tool output is sent to the model
and the full output is passed in to raw_output.
and the full output is passed in to artifact.

.. versionadded:: 0.2.17

.. code-block:: python

@@ -36,7 +39,7 @@ class ToolMessage(BaseMessage):

ToolMessage(
content=tool_output["stdout"],
raw_output=tool_output,
artifact=tool_output,
tool_call_id='call_Jja7J89XsjrOLA5r!MEOW!SL',
)

@@ -54,12 +57,14 @@ class ToolMessage(BaseMessage):
type: Literal["tool"] = "tool"
"""The type of the message (used for serialization). Defaults to "tool"."""

raw_output: Any = None
"""The raw output of the tool.
artifact: Any = None
"""Artifact of the Tool execution which is not meant to be sent to the model.

**Not part of the payload sent to the model.** Should only be specified if it is
different from the message content, i.e. if only a subset of the full tool output
is being passed as message content.
Should only be specified if it is different from the message content, e.g. if only
a subset of the full tool output is being passed as message content but the full
output is needed in other parts of the code.

.. versionadded:: 0.2.17
"""

@classmethod
@@ -106,7 +111,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
return self.__class__(
tool_call_id=self.tool_call_id,
content=merge_content(self.content, other.content),
raw_output=merge_obj(self.raw_output, other.raw_output),
artifact=merge_obj(self.artifact, other.artifact),
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
@@ -146,6 +151,11 @@ class ToolCall(TypedDict):
An identifier is needed to associate a tool call request with a tool
call result in events when multiple concurrent tool calls are made.
"""
type: NotRequired[Literal["tool_call"]]


def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
return ToolCall(name=name, args=args, id=id, type="tool_call")


class ToolCallChunk(TypedDict):
@@ -176,6 +186,19 @@ class ToolCallChunk(TypedDict):
"""An identifier associated with the tool call."""
index: Optional[int]
"""The index of the tool call in a sequence."""
type: NotRequired[Literal["tool_call_chunk"]]


def tool_call_chunk(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
index: Optional[int] = None,
) -> ToolCallChunk:
return ToolCallChunk(
name=name, args=args, id=id, index=index, type="tool_call_chunk"
)


class InvalidToolCall(TypedDict):
@@ -193,6 +216,19 @@ class InvalidToolCall(TypedDict):
"""An identifier associated with the tool call."""
error: Optional[str]
"""An error message associated with the tool call."""
type: NotRequired[Literal["invalid_tool_call"]]


def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
error: Optional[str] = None,
) -> InvalidToolCall:
return InvalidToolCall(
name=name, args=args, id=id, error=error, type="invalid_tool_call"
)


def default_tool_parser(
@@ -201,25 +237,25 @@ def default_tool_parser(
"""Best-effort parsing of tools."""
tool_calls = []
invalid_tool_calls = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
for raw_tool_call in raw_tool_calls:
if "function" not in raw_tool_call:
continue
else:
function_name = tool_call["function"]["name"]
function_name = raw_tool_call["function"]["name"]
try:
function_args = json.loads(tool_call["function"]["arguments"])
parsed = ToolCall(
function_args = json.loads(raw_tool_call["function"]["arguments"])
parsed = tool_call(
name=function_name or "",
args=function_args or {},
id=tool_call.get("id"),
id=raw_tool_call.get("id"),
)
tool_calls.append(parsed)
except json.JSONDecodeError:
invalid_tool_calls.append(
InvalidToolCall(
invalid_tool_call(
name=function_name,
args=tool_call["function"]["arguments"],
id=tool_call.get("id"),
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),
error=None,
)
)
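
Alongside the switch to factories, the loop variable is renamed from `tool_call` to `raw_tool_call`; keeping the old name would have shadowed the module-level `tool_call` factory that `parsed = tool_call(...)` now relies on. A self-contained sketch of the shadowing hazard:

```python
def tool_call(**kwargs):  # stand-in for the module-level factory
    return {**kwargs, "type": "tool_call"}

raw_tool_calls = [{"name": "f"}]
for raw_tool_call in raw_tool_calls:  # renamed: `for tool_call in ...` would rebind the name
    parsed = tool_call(name=raw_tool_call["name"], args={}, id=None)
```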
@@ -236,7 +272,7 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
else:
function_args = tool_call["function"]["arguments"]
function_name = tool_call["function"]["name"]
parsed = ToolCallChunk(
parsed = tool_call_chunk(
name=function_name,
args=function_args,
id=tool_call.get("id"),

@@ -16,6 +16,7 @@ from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
from langchain_text_splitters import TextSplitter

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables.base import Runnable

AnyMessage = Union[
@@ -221,8 +223,8 @@ def _create_message_from_message_type(
elif message_type == "function":
message = FunctionMessage(content=content, **kwargs)
elif message_type == "tool":
raw_output = kwargs.get("additional_kwargs", {}).pop("raw_output", None)
message = ToolMessage(content=content, raw_output=raw_output, **kwargs)
artifact = kwargs.get("additional_kwargs", {}).pop("artifact", None)
message = ToolMessage(content=content, artifact=artifact, **kwargs)
elif message_type == "remove":
message = RemoveMessage(**kwargs)
else:
@@ -284,7 +286,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:


def convert_to_messages(
messages: Sequence[MessageLikeRepresentation],
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> List[BaseMessage]:
"""Convert a sequence of messages to a list of messages.

@@ -294,6 +296,11 @@ def convert_to_messages(
Returns:
List of messages (BaseMessages).
"""
# Import here to avoid circular imports
from langchain_core.prompt_values import PromptValue

if isinstance(messages, PromptValue):
return messages.to_messages()
return [_convert_to_message(m) for m in messages]
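
With the widened signature, a `PromptValue` can now be passed directly and is unwrapped via its `to_messages()` method. A sketch:

```python
from langchain_core.messages import convert_to_messages
from langchain_core.prompts import ChatPromptTemplate

prompt_value = ChatPromptTemplate.from_messages(
    [("system", "You are helpful."), ("human", "{q}")]
).invoke({"q": "hi"})

messages = convert_to_messages(prompt_value)  # unwrapped via to_messages()
```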
|
||||
|
||||
|
||||
@@ -329,7 +336,7 @@ def _runnable_support(func: Callable) -> Callable:
|
||||
|
||||
@_runnable_support
|
||||
def filter_messages(
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
*,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
@@ -417,7 +424,7 @@ def filter_messages(
|
||||
|
||||
@_runnable_support
|
||||
def merge_message_runs(
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
) -> List[BaseMessage]:
|
||||
"""Merge consecutive Messages of the same type.
|
||||
|
||||
@@ -451,12 +458,12 @@ def merge_message_runs(
|
||||
HumanMessage("wait your favorite food", id="bar",),
|
||||
AIMessage(
|
||||
"my favorite colo",
|
||||
tool_calls=[ToolCall(name="blah_tool", args={"x": 2}, id="123")],
|
||||
tool_calls=[ToolCall(name="blah_tool", args={"x": 2}, id="123", type="tool_call")],
|
||||
id="baz",
|
||||
),
|
||||
AIMessage(
|
||||
[{"type": "text", "text": "my favorite dish is lasagna"}],
|
||||
tool_calls=[ToolCall(name="blah_tool", args={"x": -10}, id="456")],
|
||||
tool_calls=[ToolCall(name="blah_tool", args={"x": -10}, id="456", type="tool_call")],
|
||||
id="blur",
|
||||
),
|
||||
]
|
||||
@@ -474,8 +481,8 @@ def merge_message_runs(
|
||||
{"type": "text", "text": "my favorite dish is lasagna"}
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCall({"name": "blah_tool", "args": {"x": 2}, "id": "123"),
|
||||
ToolCall({"name": "blah_tool", "args": {"x": -10}, "id": "456")
|
||||
ToolCall({"name": "blah_tool", "args": {"x": 2}, "id": "123", "type": "tool_call"}),
|
||||
ToolCall({"name": "blah_tool", "args": {"x": -10}, "id": "456", "type": "tool_call"})
|
||||
]
|
||||
id="baz"
|
||||
),
|
||||
@@ -506,7 +513,7 @@ def merge_message_runs(
|
||||
|
||||
@_runnable_support
|
||||
def trim_messages(
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
|
||||
*,
|
||||
max_tokens: int,
|
||||
token_counter: Union[
|
||||
|
||||
@@ -5,6 +5,12 @@ from typing import Any, Dict, List, Optional, Type

 from langchain_core.exceptions import OutputParserException
 from langchain_core.messages import AIMessage, InvalidToolCall
+from langchain_core.messages.tool import (
+    invalid_tool_call,
+)
+from langchain_core.messages.tool import (
+    tool_call as create_tool_call,
+)
 from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
 from langchain_core.outputs import ChatGeneration, Generation
 from langchain_core.pydantic_v1 import BaseModel, ValidationError
@@ -59,6 +65,7 @@ def parse_tool_call(
         }
     if return_id:
         parsed["id"] = raw_tool_call.get("id")
+        parsed = create_tool_call(**parsed)  # type: ignore
     return parsed


@@ -75,7 +82,7 @@ def make_invalid_tool_call(
     Returns:
         An InvalidToolCall instance with the error message.
     """
-    return InvalidToolCall(
+    return invalid_tool_call(
         name=raw_tool_call["function"]["name"],
         args=raw_tool_call["function"]["arguments"],
         id=raw_tool_call.get("id"),
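These hunks switch the parser to the `tool_call` / `invalid_tool_call` factory helpers from `langchain_core.messages.tool`, which stamp the new `type` discriminator onto the emitted dicts. A hedged sketch of what the factories produce (the argument values are made up):

```python
from langchain_core.messages.tool import invalid_tool_call, tool_call

good = tool_call(name="blah_tool", args={"x": 2}, id="123")
assert good["type"] == "tool_call"

# invalid_tool_call keeps the raw argument string plus an error message.
bad = invalid_tool_call(name="blah_tool", args='{"x": ', id="456", error="bad JSON")
assert bad["type"] == "invalid_tool_call"
```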
(File diff suppressed because it is too large.)
@@ -48,6 +48,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):

     If no condition evaluates to True, the default branch is run on the input.

+    Parameters:
+        branches: A list of (condition, Runnable) pairs.
+        default: A Runnable to run if no condition is met.
+
     Examples:

         .. code-block:: python
@@ -82,7 +86,18 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
             RunnableLike,  # To accommodate the default branch
         ],
     ) -> None:
-        """A Runnable that runs one of two branches based on a condition."""
+        """A Runnable that runs one of two branches based on a condition.
+
+        Args:
+            *branches: A list of (condition, Runnable) pairs.
+                The final element is the default Runnable, run if no condition is met.
+
+        Raises:
+            ValueError: If the number of branches is less than 2.
+            TypeError: If the default branch is not Runnable, Callable or Mapping.
+            TypeError: If a branch is not a tuple or list.
+            ValueError: If a branch is not of length 2.
+        """
         if len(branches) < 2:
             raise ValueError("RunnableBranch requires at least two branches")

@@ -93,7 +108,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
             (Runnable, Callable, Mapping),  # type: ignore[arg-type]
         ):
             raise TypeError(
-                "RunnableBranch default must be runnable, callable or mapping."
+                "RunnableBranch default must be Runnable, callable or mapping."
             )

         default_ = cast(
@@ -176,7 +191,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
     def invoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> Output:
-        """First evaluates the condition, then delegate to true or false branch."""
+        """First evaluates the condition, then delegates to the matching branch.
+
+        Args:
+            input: The input to the Runnable.
+            config: The configuration for the Runnable. Defaults to None.
+            **kwargs: Additional keyword arguments to pass to the Runnable.
+
+        Returns:
+            The output of the branch that was run.
+        """
         config = ensure_config(config)
         callback_manager = get_callback_manager_for_config(config)
         run_manager = callback_manager.on_chain_start(
@@ -277,7 +304,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
         **kwargs: Optional[Any],
     ) -> Iterator[Output]:
         """First evaluates the condition,
-        then delegate to true or false branch."""
+        then delegates to the matching branch.
+
+        Args:
+            input: The input to the Runnable.
+            config: The configuration for the Runnable. Defaults to None.
+            **kwargs: Additional keyword arguments to pass to the Runnable.
+
+        Yields:
+            The output of the branch that was run.
+
+        Raises:
+            BaseException: If an error occurs during the execution of the Runnable.
+        """
         config = ensure_config(config)
         callback_manager = get_callback_manager_for_config(config)
         run_manager = callback_manager.on_chain_start(
@@ -352,7 +391,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
         **kwargs: Optional[Any],
     ) -> AsyncIterator[Output]:
         """First evaluates the condition,
-        then delegate to true or false branch."""
+        then delegates to the matching branch.
+
+        Args:
+            input: The input to the Runnable.
+            config: The configuration for the Runnable. Defaults to None.
+            **kwargs: Additional keyword arguments to pass to the Runnable.
+
+        Yields:
+            The output of the branch that was run.
+
+        Raises:
+            BaseException: If an error occurs during the execution of the Runnable.
+        """
         config = ensure_config(config)
         callback_manager = get_async_callback_manager_for_config(config)
         run_manager = await callback_manager.on_chain_start(
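For reference, a minimal `RunnableBranch` along the lines of the class docstring: condition/Runnable pairs followed by a default, with `invoke` running the first branch whose condition matches.

```python
from langchain_core.runnables import RunnableBranch

branch = RunnableBranch(
    (lambda x: isinstance(x, str), lambda x: x.upper()),
    (lambda x: isinstance(x, int), lambda x: x + 1),
    lambda x: "goodbye",  # default branch when no condition matches
)

branch.invoke("hello")  # -> "HELLO"
branch.invoke(2)        # -> 3
branch.invoke(None)     # -> "goodbye"
```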
@@ -111,7 +111,7 @@ var_child_runnable_config = ContextVar(


 def _set_config_context(config: RunnableConfig) -> None:
-    """Set the child runnable config + tracing context
+    """Set the child Runnable config + tracing context.

     Args:
         config (RunnableConfig): The config to set.
@@ -216,7 +216,6 @@ def patch_config(

     Args:
         config (Optional[RunnableConfig]): The config to patch.
-        copy_locals (bool, optional): Whether to copy locals. Defaults to False.
         callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
             Defaults to None.
         recursion_limit (Optional[int], optional): The recursion limit to set.
@@ -362,9 +361,9 @@ def call_func_with_variable_args(
             Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
             The function to call.
         input (Input): The input to the function.
-        run_manager (CallbackManagerForChainRun): The run manager to
-            pass to the function.
         config (RunnableConfig): The config to pass to the function.
+        run_manager (CallbackManagerForChainRun): The run manager to
+            pass to the function. Defaults to None.
         **kwargs (Any): The keyword arguments to pass to the function.

     Returns:
@@ -395,7 +394,7 @@ def acall_func_with_variable_args(
     run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     **kwargs: Any,
 ) -> Awaitable[Output]:
-    """Call function that may optionally accept a run_manager and/or config.
+    """Async call function that may optionally accept a run_manager and/or config.

     Args:
         func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
@@ -403,9 +402,9 @@ def acall_func_with_variable_args(
             AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
             The function to call.
         input (Input): The input to the function.
-        run_manager (AsyncCallbackManagerForChainRun): The run manager
-            to pass to the function.
         config (RunnableConfig): The config to pass to the function.
+        run_manager (AsyncCallbackManagerForChainRun): The run manager
+            to pass to the function. Defaults to None.
         **kwargs (Any): The keyword arguments to pass to the function.

     Returns:
@@ -493,6 +492,18 @@ class ContextThreadPoolExecutor(ThreadPoolExecutor):
         timeout: float | None = None,
         chunksize: int = 1,
     ) -> Iterator[T]:
-        """Map a function to multiple iterables."""
+        """Map a function to multiple iterables.
+
+        Args:
+            fn (Callable[..., T]): The function to map.
+            *iterables (Iterable[Any]): The iterables to map over.
+            timeout (float | None, optional): The timeout for the map.
+                Defaults to None.
+            chunksize (int, optional): The chunksize for the map. Defaults to 1.
+
+        Returns:
+            Iterator[T]: The iterator for the mapped function.
+        """
         contexts = [copy_context() for _ in range(len(iterables[0]))]  # type: ignore[arg-type]

         def _wrapped_fn(*args: Any) -> T:
@@ -534,13 +545,16 @@ async def run_in_executor(
     """Run a function in an executor.

     Args:
-        executor (Executor): The executor.
+        executor_or_config: The executor or config to run in.
         func (Callable[P, Output]): The function.
         *args (Any): The positional arguments to the function.
         **kwargs (Any): The keyword arguments to the function.

     Returns:
         Output: The output of the function.
+
+    Raises:
+        RuntimeError: If the function raises a StopIteration.
     """

     def wrapper() -> T:
@@ -44,7 +44,15 @@ from langchain_core.runnables.utils import (


 class DynamicRunnable(RunnableSerializable[Input, Output]):
-    """Serializable Runnable that can be dynamically configured."""
+    """Serializable Runnable that can be dynamically configured.
+
+    A DynamicRunnable should be initiated using the `configurable_fields` or
+    `configurable_alternatives` method of a Runnable.
+
+    Parameters:
+        default: The default Runnable to use.
+        config: The configuration to use.
+    """

     default: RunnableSerializable[Input, Output]

@@ -99,6 +107,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
     def prepare(
         self, config: Optional[RunnableConfig] = None
     ) -> Tuple[Runnable[Input, Output], RunnableConfig]:
+        """Prepare the Runnable for invocation.
+
+        Args:
+            config: The configuration to use. Defaults to None.
+
+        Returns:
+            Tuple[Runnable[Input, Output], RunnableConfig]: The prepared Runnable and
+                configuration.
+        """
         runnable: Runnable[Input, Output] = self
         while isinstance(runnable, DynamicRunnable):
             runnable, config = runnable._prepare(merge_configs(runnable.config, config))
@@ -284,6 +301,9 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
     A RunnableConfigurableFields should be initiated using the
     `configurable_fields` method of a Runnable.

+    Parameters:
+        fields: The configurable fields to use.
+
     Here is an example of using a RunnableConfigurableFields with LLMs:

         .. code-block:: python
@@ -348,6 +368,11 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):

     @property
     def config_specs(self) -> List[ConfigurableFieldSpec]:
+        """Get the configuration specs for the RunnableConfigurableFields.
+
+        Returns:
+            List[ConfigurableFieldSpec]: The configuration specs.
+        """
         return get_unique_config_specs(
             [
                 (
@@ -374,6 +399,8 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
     def configurable_fields(
         self, **kwargs: AnyConfigurableField
     ) -> RunnableSerializable[Input, Output]:
+        """Get a new RunnableConfigurableFields with the specified
+        configurable fields."""
         return self.default.configurable_fields(**{**self.fields, **kwargs})

     def _prepare(
@@ -493,11 +520,13 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
     """  # noqa: E501

     which: ConfigurableField
+    """The ConfigurableField to use to choose between alternatives."""

     alternatives: Dict[
         str,
         Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
     ]
+    """The alternatives to choose from."""

     default_key: str = "default"
     """The enum value to use for the default option. Defaults to "default"."""
@@ -619,7 +648,7 @@ def prefix_config_spec(
         prefix: The prefix to add.

     Returns:
-
+        ConfigurableFieldSpec: The prefixed ConfigurableFieldSpec.
     """
     return (
         ConfigurableFieldSpec(
@@ -641,6 +670,13 @@ def make_options_spec(
 ) -> ConfigurableFieldSpec:
     """Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
     ConfigurableFieldMultiOption.
+
+    Args:
+        spec: The ConfigurableFieldSingleOption or ConfigurableFieldMultiOption.
+        description: The description to use if the spec does not have one.
+
+    Returns:
+        The ConfigurableFieldSpec.
     """
     with _enums_for_spec_lock:
         if enum := _enums_for_spec.get(spec):
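A short sketch of how a `RunnableConfigurableFields` is typically created and used; `model` stands in for any Runnable exposing a `temperature` field and is an assumption here, not part of the diff.

```python
from langchain_core.runnables import ConfigurableField

# `model` is assumed to be e.g. a chat model with a `temperature` attribute.
configurable_model = model.configurable_fields(
    temperature=ConfigurableField(
        id="llm_temperature",
        name="LLM temperature",
        description="Sampling temperature used by the model",
    )
)

# Override the field per-invocation through the config:
configurable_model.invoke(
    "pick a random number",
    config={"configurable": {"llm_temperature": 0.9}},
)
```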
@@ -91,7 +91,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
     """

     runnable: Runnable[Input, Output]
-    """The runnable to run first."""
+    """The Runnable to run first."""
     fallbacks: Sequence[Runnable[Input, Output]]
     """A sequence of fallbacks to try."""
     exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
@@ -102,7 +102,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
     exception_key: Optional[str] = None
     """If string is specified then handled exceptions will be passed to fallbacks as
     part of the input under the specified key. If None, exceptions
-    will not be passed to fallbacks. If used, the base runnable and its fallbacks
+    will not be passed to fallbacks. If used, the base Runnable and its fallbacks
     must accept a dictionary as input."""

     class Config:
@@ -554,7 +554,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
         await run_manager.on_chain_end(output)

     def __getattr__(self, name: str) -> Any:
-        """Get an attribute from the wrapped runnable and its fallbacks.
+        """Get an attribute from the wrapped Runnable and its fallbacks.

         Returns:
             If the attribute is anything other than a method that outputs a Runnable,
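A `RunnableWithFallbacks` is normally built via `.with_fallbacks()`; a hedged sketch, where `primary` and `backup` are assumed to be any two compatible Runnables:

```python
# Try `primary` first, fall back to `backup` on any Exception.
chain = primary.with_fallbacks([backup])

# With exception_key set, the handled exception is injected into the
# fallback's input dict under that key (inputs must then be dicts).
chain_with_reason = primary.with_fallbacks([backup], exception_key="exception")
```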
@@ -57,7 +57,14 @@ def is_uuid(value: str) -> bool:


 class Edge(NamedTuple):
-    """Edge in a graph."""
+    """Edge in a graph.
+
+    Parameters:
+        source: The source node id.
+        target: The target node id.
+        data: Optional data associated with the edge. Defaults to None.
+        conditional: Whether the edge is conditional. Defaults to False.
+    """

     source: str
     target: str
@@ -67,6 +74,15 @@ class Edge(NamedTuple):
     def copy(
         self, *, source: Optional[str] = None, target: Optional[str] = None
     ) -> Edge:
+        """Return a copy of the edge with optional new source and target nodes.
+
+        Args:
+            source: The new source node id. Defaults to None.
+            target: The new target node id. Defaults to None.
+
+        Returns:
+            A copy of the edge with the new source and target nodes.
+        """
         return Edge(
             source=source or self.source,
             target=target or self.target,
@@ -76,7 +92,14 @@ class Edge(NamedTuple):


 class Node(NamedTuple):
-    """Node in a graph."""
+    """Node in a graph.
+
+    Parameters:
+        id: The unique identifier of the node.
+        name: The name of the node.
+        data: The data of the node.
+        metadata: Optional metadata for the node. Defaults to None.
+    """

     id: str
     name: str
@@ -84,6 +107,15 @@ class Node(NamedTuple):
     metadata: Optional[Dict[str, Any]]

     def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
+        """Return a copy of the node with optional new id and name.
+
+        Args:
+            id: The new node id. Defaults to None.
+            name: The new node name. Defaults to None.
+
+        Returns:
+            A copy of the node with the new id and name.
+        """
         return Node(
             id=id or self.id,
             name=name or self.name,
@@ -93,7 +125,13 @@ class Node(NamedTuple):


 class Branch(NamedTuple):
-    """Branch in a graph."""
+    """Branch in a graph.
+
+    Parameters:
+        condition: A callable that returns a string representation of the condition.
+        ends: Optional dictionary of end node ids for the branches. Defaults
+            to None.
+    """

     condition: Callable[..., str]
     ends: Optional[dict[str, str]]
@@ -117,12 +155,18 @@ class CurveStyle(Enum):


 @dataclass
-class NodeColors:
-    """Schema for Hexadecimal color codes for different node types"""
-
-    start: str = "#ffdfba"
-    end: str = "#baffc9"
-    other: str = "#fad7de"
+class NodeStyles:
+    """Schema for Hexadecimal color codes for different node types.
+
+    Parameters:
+        default: The default color code. Defaults to "fill:#f2f0ff,line-height:1.2".
+        first: The color code for the first node. Defaults to "fill-opacity:0".
+        last: The color code for the last node. Defaults to "fill:#bfb6fc".
+    """
+
+    default: str = "fill:#f2f0ff,line-height:1.2"
+    first: str = "fill-opacity:0"
+    last: str = "fill:#bfb6fc"
 class MermaidDrawMethod(Enum):
@@ -161,7 +205,7 @@ def node_data_json(
     Args:
         node: The node to convert.
         with_schemas: Whether to include the schema of the data if
-            it is a Pydantic model.
+            it is a Pydantic model. Defaults to False.

     Returns:
         A dictionary with the type of the data and the data itself.
@@ -209,13 +253,26 @@ def node_data_json(

 @dataclass
 class Graph:
-    """Graph of nodes and edges."""
+    """Graph of nodes and edges.
+
+    Parameters:
+        nodes: Dictionary of nodes in the graph. Defaults to an empty dictionary.
+        edges: List of edges in the graph. Defaults to an empty list.
+    """

     nodes: Dict[str, Node] = field(default_factory=dict)
     edges: List[Edge] = field(default_factory=list)

     def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
-        """Convert the graph to a JSON-serializable format."""
+        """Convert the graph to a JSON-serializable format.
+
+        Args:
+            with_schemas: Whether to include the schemas of the nodes if they are
+                Pydantic models. Defaults to False.
+
+        Returns:
+            A dictionary with the nodes and edges of the graph.
+        """
         stable_node_ids = {
             node.id: i if is_uuid(node.id) else node.id
             for i, node in enumerate(self.nodes.values())
@@ -247,6 +304,8 @@ class Graph:
         return bool(self.nodes)

     def next_id(self) -> str:
+        """Return a new unique node
+        identifier that can be used to add a node to the graph."""
         return uuid4().hex

     def add_node(
@@ -256,7 +315,19 @@ class Graph:
         *,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> Node:
-        """Add a node to the graph and return it."""
+        """Add a node to the graph and return it.
+
+        Args:
+            data: The data of the node.
+            id: The id of the node. Defaults to None.
+            metadata: Optional metadata for the node. Defaults to None.
+
+        Returns:
+            The node that was added to the graph.
+
+        Raises:
+            ValueError: If a node with the same id already exists.
+        """
         if id is not None and id in self.nodes:
             raise ValueError(f"Node with id {id} already exists")
         id = id or self.next_id()
@@ -265,7 +336,11 @@ class Graph:
         return node

     def remove_node(self, node: Node) -> None:
-        """Remove a node from the graph and all edges connected to it."""
+        """Remove a node from the graph and all edges connected to it.
+
+        Args:
+            node: The node to remove.
+        """
         self.nodes.pop(node.id)
         self.edges = [
             edge
@@ -280,7 +355,20 @@ class Graph:
         data: Optional[Stringifiable] = None,
         conditional: bool = False,
     ) -> Edge:
-        """Add an edge to the graph and return it."""
+        """Add an edge to the graph and return it.
+
+        Args:
+            source: The source node of the edge.
+            target: The target node of the edge.
+            data: Optional data associated with the edge. Defaults to None.
+            conditional: Whether the edge is conditional. Defaults to False.
+
+        Returns:
+            The edge that was added to the graph.
+
+        Raises:
+            ValueError: If the source or target node is not in the graph.
+        """
         if source.id not in self.nodes:
             raise ValueError(f"Source node {source.id} not in graph")
         if target.id not in self.nodes:
@@ -295,7 +383,15 @@ class Graph:
         self, graph: Graph, *, prefix: str = ""
     ) -> Tuple[Optional[Node], Optional[Node]]:
         """Add all nodes and edges from another graph.
-        Note this doesn't check for duplicates, nor does it connect the graphs."""
+        Note this doesn't check for duplicates, nor does it connect the graphs.
+
+        Args:
+            graph: The graph to add.
+            prefix: The prefix to add to the node ids. Defaults to "".
+
+        Returns:
+            A tuple of the first and last nodes of the subgraph.
+        """
         if all(is_uuid(node.id) for node in graph.nodes.values()):
             prefix = ""
@@ -350,7 +446,7 @@ class Graph:
     def first_node(self) -> Optional[Node]:
         """Find the single node that is not a target of any edge.
         If there is no such node, or there are multiple, return None.
-        When drawing the graph this node would be the origin."""
+        When drawing the graph, this node would be the origin."""
         targets = {edge.target for edge in self.edges}
         found: List[Node] = []
         for node in self.nodes.values():
@@ -361,7 +457,7 @@ class Graph:
     def last_node(self) -> Optional[Node]:
         """Find the single node that is not a source of any edge.
         If there is no such node, or there are multiple, return None.
-        When drawing the graph this node would be the destination.
+        When drawing the graph, this node would be the destination.
         """
         sources = {edge.source for edge in self.edges}
         found: List[Node] = []
@@ -372,7 +468,7 @@ class Graph:

     def trim_first_node(self) -> None:
         """Remove the first node if it exists and has a single outgoing edge,
-        ie. if removing it would not leave the graph without a "first" node."""
+        i.e., if removing it would not leave the graph without a "first" node."""
         first_node = self.first_node()
         if first_node:
             if (
@@ -384,7 +480,7 @@ class Graph:

     def trim_last_node(self) -> None:
         """Remove the last node if it exists and has a single incoming edge,
-        ie. if removing it would not leave the graph without a "last" node."""
+        i.e., if removing it would not leave the graph without a "last" node."""
         last_node = self.last_node()
         if last_node:
             if (
@@ -395,6 +491,7 @@ class Graph:
         self.remove_node(last_node)

     def draw_ascii(self) -> str:
+        """Draw the graph as an ASCII art string."""
         from langchain_core.runnables.graph_ascii import draw_ascii

         return draw_ascii(
@@ -403,6 +500,7 @@ class Graph:
         )

     def print_ascii(self) -> None:
+        """Print the graph as an ASCII art string."""
         print(self.draw_ascii())  # noqa: T201

     @overload
@@ -427,6 +525,17 @@ class Graph:
         fontname: Optional[str] = None,
         labels: Optional[LabelsDict] = None,
     ) -> Union[bytes, None]:
+        """Draw the graph as a PNG image.
+
+        Args:
+            output_file_path: The path to save the image to. If None, the image
+                is not saved. Defaults to None.
+            fontname: The name of the font to use. Defaults to None.
+            labels: Optional labels for nodes and edges in the graph. Defaults to None.
+
+        Returns:
+            The PNG image as bytes if output_file_path is None, None otherwise.
+        """
         from langchain_core.runnables.graph_png import PngDrawer

         default_node_labels = {node.id: node.name for node in self.nodes.values()}
@@ -447,11 +556,21 @@ class Graph:
         *,
         with_styles: bool = True,
         curve_style: CurveStyle = CurveStyle.LINEAR,
-        node_colors: NodeColors = NodeColors(
-            start="#ffdfba", end="#baffc9", other="#fad7de"
-        ),
+        node_colors: NodeStyles = NodeStyles(),
         wrap_label_n_words: int = 9,
     ) -> str:
-        """Draw the graph as a Mermaid syntax string."""
+        """Draw the graph as a Mermaid syntax string.
+
+        Args:
+            with_styles: Whether to include styles in the syntax. Defaults to True.
+            curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
+            node_colors: The colors of the nodes. Defaults to NodeStyles().
+            wrap_label_n_words: The number of words to wrap the node labels at.
+                Defaults to 9.
+
+        Returns:
+            The Mermaid syntax string.
+        """
         from langchain_core.runnables.graph_mermaid import draw_mermaid

         graph = self.reid()
@@ -465,7 +584,7 @@ class Graph:
             last_node=last_node.id if last_node else None,
             with_styles=with_styles,
             curve_style=curve_style,
-            node_colors=node_colors,
+            node_styles=node_colors,
             wrap_label_n_words=wrap_label_n_words,
         )

@@ -473,15 +592,30 @@ class Graph:
         self,
         *,
         curve_style: CurveStyle = CurveStyle.LINEAR,
-        node_colors: NodeColors = NodeColors(
-            start="#ffdfba", end="#baffc9", other="#fad7de"
-        ),
+        node_colors: NodeStyles = NodeStyles(),
         wrap_label_n_words: int = 9,
         output_file_path: Optional[str] = None,
         draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
         background_color: str = "white",
         padding: int = 10,
     ) -> bytes:
-        """Draw the graph as a PNG image using Mermaid."""
+        """Draw the graph as a PNG image using Mermaid.
+
+        Args:
+            curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
+            node_colors: The colors of the nodes. Defaults to NodeStyles().
+            wrap_label_n_words: The number of words to wrap the node labels at.
+                Defaults to 9.
+            output_file_path: The path to save the image to. If None, the image
+                is not saved. Defaults to None.
+            draw_method: The method to use to draw the graph.
+                Defaults to MermaidDrawMethod.API.
+            background_color: The color of the background. Defaults to "white".
+            padding: The padding around the graph. Defaults to 10.
+
+        Returns:
+            The PNG image as bytes.
+        """
         from langchain_core.runnables.graph_mermaid import draw_mermaid_png

         mermaid_syntax = self.draw_mermaid(
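Taken together, the `Graph` methods above are usually reached through `Runnable.get_graph()`; a small sketch of the drawing entry points:

```python
from langchain_core.runnables import RunnableLambda

chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
graph = chain.get_graph()

print(graph.draw_ascii())    # terminal rendering
print(graph.draw_mermaid())  # Mermaid syntax, styled via NodeStyles()

# Renders through the Mermaid API by default (MermaidDrawMethod.API):
# png_bytes = graph.draw_mermaid_png()
```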
@@ -17,6 +17,7 @@ class VertexViewer:
     """

     HEIGHT = 3  # top and bottom box edges + text
+    """Height of the box."""

     def __init__(self, name: str) -> None:
         self._h = self.HEIGHT  # top and bottom box edges + text
@@ -8,7 +8,7 @@ from langchain_core.runnables.graph import (
     Edge,
     MermaidDrawMethod,
     Node,
-    NodeColors,
+    NodeStyles,
 )


@@ -20,21 +20,28 @@ def draw_mermaid(
     last_node: Optional[str] = None,
     with_styles: bool = True,
     curve_style: CurveStyle = CurveStyle.LINEAR,
-    node_colors: NodeColors = NodeColors(),
+    node_styles: NodeStyles = NodeStyles(),
     wrap_label_n_words: int = 9,
 ) -> str:
-    """Draws a Mermaid graph using the provided graph data
+    """Draws a Mermaid graph using the provided graph data.

     Args:
-        nodes (dict[str, str]): List of node ids
-        edges (List[Edge]): List of edges, object with source,
-            target and data.
+        nodes (dict[str, str]): List of node ids.
+        edges (List[Edge]): List of edges, object with a source,
+            target and data.
         first_node (str, optional): Id of the first node. Defaults to None.
         last_node (str, optional): Id of the last node. Defaults to None.
         with_styles (bool, optional): Whether to include styles in the graph.
             Defaults to True.
         curve_style (CurveStyle, optional): Curve style for the edges.
-        node_colors (NodeColors, optional): Node colors for different types.
             Defaults to CurveStyle.LINEAR.
+        node_styles (NodeStyles, optional): Node colors for different types.
+            Defaults to NodeStyles().
         wrap_label_n_words (int, optional): Words to wrap the edge labels.
             Defaults to 9.

     Returns:
-        str: Mermaid graph syntax
+        str: Mermaid graph syntax.
     """
     # Initialize Mermaid graph configuration
     mermaid_graph = (
@@ -49,23 +56,27 @@ def draw_mermaid(
     if with_styles:
         # Node formatting templates
         default_class_label = "default"
-        format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
+        format_dict = {default_class_label: "{0}({1})"}
         if first_node is not None:
-            format_dict[first_node] = "{0}[{0}]:::startclass"
+            format_dict[first_node] = "{0}([{0}]):::first"
         if last_node is not None:
-            format_dict[last_node] = "{0}[{0}]:::endclass"
+            format_dict[last_node] = "{0}([{0}]):::last"

         # Add nodes to the graph
         for key, node in nodes.items():
             label = node.name.split(":")[-1]
             if node.metadata:
-                label = f"<strong>{label}</strong>\n" + "\n".join(
-                    f"{key} = {value}" for key, value in node.metadata.items()
+                label = (
+                    f"{label}<hr/><small><em>"
+                    + "\n".join(
+                        f"{key} = {value}" for key, value in node.metadata.items()
+                    )
+                    + "</em></small>"
                 )
             node_label = format_dict.get(key, format_dict[default_class_label]).format(
                 _escape_node_label(key), label
             )
-            mermaid_graph += f"\t{node_label};\n"
+            mermaid_graph += f"\t{node_label}\n"

     subgraph = ""
     # Add edges to the graph
@@ -89,16 +100,14 @@ def draw_mermaid(
             words = str(edge_data).split()  # Split the string into words
             # Group words into chunks of wrap_label_n_words size
             if len(words) > wrap_label_n_words:
-                edge_data = "<br>".join(
-                    [
-                        " ".join(words[i : i + wrap_label_n_words])
-                        for i in range(0, len(words), wrap_label_n_words)
-                    ]
+                edge_data = " <br> ".join(
+                    " ".join(words[i : i + wrap_label_n_words])
+                    for i in range(0, len(words), wrap_label_n_words)
                 )
             if edge.conditional:
-                edge_label = f" -. {edge_data} .-> "
+                edge_label = f" -.  {edge_data}  .-> "
             else:
-                edge_label = f" -- {edge_data} --> "
+                edge_label = f" --  {edge_data}  --> "
         else:
             if edge.conditional:
                 edge_label = " -.-> "
@@ -113,7 +122,7 @@ def draw_mermaid(

     # Add custom styles for nodes
     if with_styles:
-        mermaid_graph += _generate_mermaid_graph_styles(node_colors)
+        mermaid_graph += _generate_mermaid_graph_styles(node_styles)
     return mermaid_graph


@@ -122,11 +131,11 @@ def _escape_node_label(node_label: str) -> str:
     return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)


-def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
+def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
     """Generates Mermaid graph styles for different node types."""
     styles = ""
-    for class_name, color in asdict(node_colors).items():
-        styles += f"\tclassDef {class_name}class fill:{color};\n"
+    for class_name, style in asdict(node_colors).items():
+        styles += f"\tclassDef {class_name} {style}\n"
     return styles


@@ -137,7 +146,24 @@ def draw_mermaid_png(
     background_color: Optional[str] = "white",
     padding: int = 10,
 ) -> bytes:
-    """Draws a Mermaid graph as PNG using provided syntax."""
+    """Draws a Mermaid graph as PNG using provided syntax.
+
+    Args:
+        mermaid_syntax (str): Mermaid graph syntax.
+        output_file_path (str, optional): Path to save the PNG image.
+            Defaults to None.
+        draw_method (MermaidDrawMethod, optional): Method to draw the graph.
+            Defaults to MermaidDrawMethod.API.
+        background_color (str, optional): Background color of the image.
+            Defaults to "white".
+        padding (int, optional): Padding around the image. Defaults to 10.
+
+    Returns:
+        bytes: PNG image bytes.
+
+    Raises:
+        ValueError: If an invalid draw method is provided.
+    """
     if draw_method == MermaidDrawMethod.PYPPETEER:
         import asyncio
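The edge-label wrapping above reduces to a small chunking helper; restated standalone for clarity:

```python
def wrap_edge_label(label: str, wrap_label_n_words: int = 9) -> str:
    """Group the label's words into chunks joined by ' <br> '."""
    words = label.split()
    if len(words) <= wrap_label_n_words:
        return label
    return " <br> ".join(
        " ".join(words[i : i + wrap_label_n_words])
        for i in range(0, len(words), wrap_label_n_words)
    )

wrap_edge_label("one two three four five", wrap_label_n_words=3)
# -> 'one two three <br> four five'
```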
@@ -6,7 +6,7 @@ from langchain_core.runnables.graph import Graph, LabelsDict
 class PngDrawer:
     """Helper class to draw a state graph into a PNG file.

-    It requires graphviz and pygraphviz to be installed.
+    It requires `graphviz` and `pygraphviz` to be installed.
     :param fontname: The font to use for the labels
     :param labels: A dictionary of label overrides. The dictionary
         should have the following format:
@@ -33,7 +33,7 @@ class PngDrawer:
         """Initializes the PNG drawer.

         Args:
-            fontname: The font to use for the labels
+            fontname: The font to use for the labels. Defaults to "arial".
             labels: A dictionary of label overrides. The dictionary
                 should have the following format:
                 {
@@ -48,6 +48,7 @@ class PngDrawer:
                     }
                 }
                 The keys are the original labels, and the values are the new labels.
+                Defaults to None.
         """
         self.fontname = fontname or "arial"
         self.labels = labels or LabelsDict(nodes={}, edges={})
@@ -56,7 +57,7 @@ class PngDrawer:
         """Returns the label to use for a node.

         Args:
-            label: The original label
+            label: The original label.

         Returns:
             The new label.
@@ -68,7 +69,7 @@ class PngDrawer:
         """Returns the label to use for an edge.

         Args:
-            label: The original label
+            label: The original label.

         Returns:
             The new label.
@@ -80,8 +81,8 @@ class PngDrawer:
         """Adds a node to the graph.

         Args:
-            viz: The graphviz object
-            node: The node to add
+            viz: The graphviz object.
+            node: The node to add.

         Returns:
             None
@@ -106,9 +107,9 @@ class PngDrawer:
         """Adds an edge to the graph.

         Args:
-            viz: The graphviz object
-            source: The source node
-            target: The target node
+            viz: The graphviz object.
+            source: The source node.
+            target: The target node.
             label: The label for the edge. Defaults to None.
             conditional: Whether the edge is conditional. Defaults to False.

@@ -127,7 +128,7 @@ class PngDrawer:
     def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
         """Draw the given state graph into a PNG file.

-        Requires graphviz and pygraphviz to be installed.
+        Requires `graphviz` and `pygraphviz` to be installed.
         :param graph: The graph to draw
         :param output_path: The path to save the PNG. If None, PNG bytes are returned.
         """
@@ -156,14 +157,32 @@ class PngDrawer:
         viz.close()

     def add_nodes(self, viz: Any, graph: Graph) -> None:
+        """Add nodes to the graph.
+
+        Args:
+            viz: The graphviz object.
+            graph: The graph to draw.
+        """
         for node in graph.nodes:
             self.add_node(viz, node)

     def add_edges(self, viz: Any, graph: Graph) -> None:
+        """Add edges to the graph.
+
+        Args:
+            viz: The graphviz object.
+            graph: The graph to draw.
+        """
         for start, end, data, cond in graph.edges:
             self.add_edge(viz, start, end, str(data), cond)

     def update_styles(self, viz: Any, graph: Graph) -> None:
+        """Update the styles of the entrypoint and END nodes.
+
+        Args:
+            viz: The graphviz object.
+            graph: The graph to draw.
+        """
         if first := graph.first_node():
             viz.get_node(first.id).attr.update(fillcolor="lightblue")
         if last := graph.last_node():
@@ -45,13 +45,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     history for it; it is responsible for reading and updating the chat message
     history.

-    The formats supports for the inputs and outputs of the wrapped Runnable
+    The formats supported for the inputs and outputs of the wrapped Runnable
     are described below.

     RunnableWithMessageHistory must always be called with a config that contains
     the appropriate parameters for the chat message history factory.

-    By default the Runnable is expected to take a single configuration parameter
+    By default, the Runnable is expected to take a single configuration parameter
     called `session_id` which is a string. This parameter is used to create a new
     or look up an existing chat message history that matches the given session_id.

@@ -70,6 +70,19 @@ class RunnableWithMessageHistory(RunnableBindingBase):
     For production use cases, you will want to use a persistent implementation
     of chat message history, such as ``RedisChatMessageHistory``.

+    Parameters:
+        get_session_history: Function that returns a new BaseChatMessageHistory.
+            This function should take a single positional argument
+            `session_id` of type string and return a corresponding
+            chat message history instance.
+        input_messages_key: Must be specified if the base Runnable accepts a dict
+            as input. The key in the input dict that contains the messages.
+        output_messages_key: Must be specified if the base Runnable returns a dict
+            as output. The key in the output dict that contains the messages.
+        history_messages_key: Must be specified if the base Runnable accepts a dict
+            as input and expects a separate key for historical messages.
+        history_factory_config: Configure fields that should be passed to the
+            chat history factory. See ``ConfigurableFieldSpec`` for more details.
+
     Example: Chat message history with an in-memory implementation for testing.

@@ -287,9 +300,9 @@ class RunnableWithMessageHistory(RunnableBindingBase):
                 ...

             input_messages_key: Must be specified if the base runnable accepts a dict
-                as input.
+                as input. Default is None.
             output_messages_key: Must be specified if the base runnable returns a dict
-                as output.
+                as output. Default is None.
             history_messages_key: Must be specified if the base runnable accepts a dict
                 as input and expects a separate key for historical messages.
             history_factory_config: Configure fields that should be passed to the
@@ -347,6 +360,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):

     @property
     def config_specs(self) -> List[ConfigurableFieldSpec]:
+        """Get the configuration specs for the RunnableWithMessageHistory."""
         return get_unique_config_specs(
             super().config_specs + list(self.history_factory_config)
         )
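A minimal in-memory wiring of the parameters documented above; `chat_model` is an assumption here (any Runnable that accepts messages works):

```python
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

store = {}

def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    # One history object per session_id, created lazily.
    if session_id not in store:
        store[session_id] = InMemoryChatMessageHistory()
    return store[session_id]

with_history = RunnableWithMessageHistory(chat_model, get_session_history)
with_history.invoke(
    "hi, I'm Bob",
    config={"configurable": {"session_id": "1"}},
)
```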
@@ -53,19 +53,33 @@ if TYPE_CHECKING:


 def identity(x: Other) -> Other:
-    """Identity function"""
+    """Identity function.
+
+    Args:
+        x (Other): input.
+
+    Returns:
+        Other: output.
+    """
     return x


 async def aidentity(x: Other) -> Other:
-    """Async identity function"""
+    """Async identity function.
+
+    Args:
+        x (Other): input.
+
+    Returns:
+        Other: output.
+    """
     return x


 class RunnablePassthrough(RunnableSerializable[Other, Other]):
     """Runnable to passthrough inputs unchanged or with additional keys.

-    This runnable behaves almost like the identity function, except that it
+    This Runnable behaves almost like the identity function, except that it
     can be configured to add additional keys to the output, if the input is a
     dict.

@@ -73,6 +87,13 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
     chains. The chains rely on simple lambdas to make the examples easy to execute
     and experiment with.

+    Parameters:
+        func (Callable[[Other], None], optional): Function to be called with the input.
+        afunc (Callable[[Other], Awaitable[None]], optional): Async function to
+            be called with the input.
+        input_type (Optional[Type[Other]], optional): Type of the input.
+        **kwargs (Any): Additional keyword arguments.
+
     Examples:

         .. code-block:: python
@@ -199,10 +220,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
         """Merge the Dict input with the output produced by the mapping argument.

         Args:
-            mapping: A mapping from keys to runnables or callables.
+            **kwargs: Runnable, Callable or a Mapping from keys to Runnables
+                or Callables.

         Returns:
-            A runnable that merges the Dict input with the output produced by the
+            A Runnable that merges the Dict input with the output produced by the
                 mapping argument.
         """
         return RunnableAssign(RunnableParallel(kwargs))
@@ -336,6 +358,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
     these with the original data, introducing new key-value pairs based
     on the mapper's logic.

+    Parameters:
+        mapper (RunnableParallel[Dict[str, Any]]): A `RunnableParallel` instance
+            that will be used to transform the input dictionary.
+
     Examples:
         .. code-block:: python

@@ -627,11 +653,15 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
 class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
     """Runnable that picks keys from Dict[str, Any] inputs.

-    RunnablePick class represents a runnable that selectively picks keys from a
+    RunnablePick class represents a Runnable that selectively picks keys from a
     dictionary input. It allows you to specify one or more keys to extract
     from the input dictionary. It returns a new dictionary containing only
     the selected keys.

+    Parameters:
+        keys (Union[str, List[str]]): A single key or a list of keys to pick from
+            the input dictionary.
+
     Example:
         .. code-block:: python
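A compact sketch combining `RunnablePassthrough.assign` and `RunnablePick` as documented above:

```python
from langchain_core.runnables import RunnablePassthrough, RunnablePick

chain = (
    RunnablePassthrough.assign(total=lambda d: d["a"] + d["b"])  # add a key
    | RunnablePick("total")  # keep only that key's value
)
chain.invoke({"a": 1, "b": 2})  # -> 3
```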
@@ -112,7 +112,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
     """Whether to add jitter to the exponential backoff."""

     max_attempt_number: int = 3
-    """The maximum number of attempts to retry the runnable."""
+    """The maximum number of attempts to retry the Runnable."""

     @classmethod
     def get_lc_namespace(cls) -> List[str]:
@@ -38,7 +38,7 @@ class RouterInput(TypedDict):

     Attributes:
         key: The key to route on.
-        input: The input to pass to the selected runnable.
+        input: The input to pass to the selected Runnable.
     """

     key: str
@@ -50,6 +50,9 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
     Runnable that routes to a set of Runnables based on Input['key'].
     Returns the output of the selected Runnable.

+    Parameters:
+        runnables: A mapping of keys to Runnables.
+
     For example,

     .. code-block:: python
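Per the `RouterInput` schema above, a `RouterRunnable` dispatches on `input['key']`:

```python
from langchain_core.runnables import RouterRunnable, RunnableLambda

router = RouterRunnable(
    runnables={
        "add": RunnableLambda(lambda x: x + 1),
        "square": RunnableLambda(lambda x: x**2),
    }
)

router.invoke({"key": "square", "input": 3})  # -> 9
```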
@@ -1,4 +1,4 @@
-"""Module contains typedefs that are used with runnables."""
+"""Module contains typedefs that are used with Runnables."""

 from __future__ import annotations

@@ -11,7 +11,7 @@ class EventData(TypedDict, total=False):
     """Data associated with a streaming event."""

     input: Any
-    """The input passed to the runnable that generated the event.
+    """The input passed to the Runnable that generated the event.

     Inputs will sometimes be available at the *START* of the Runnable, and
     sometimes at the *END* of the Runnable.
@@ -85,40 +85,43 @@ class BaseStreamEvent(TypedDict):
     event: str
     """Event names are of the format: on_[runnable_type]_(start|stream|end).

-    Runnable types are one of:
-    * llm - used by non chat models
-    * chat_model - used by chat models
-    * prompt -- e.g., ChatPromptTemplate
-    * tool -- from tools defined via @tool decorator or inheriting from Tool/BaseTool
-    * chain - most Runnables are of this type
+    Runnable types are one of:
+
+    - **llm** - used by non chat models
+    - **chat_model** - used by chat models
+    - **prompt** -- e.g., ChatPromptTemplate
+    - **tool** -- from tools defined via @tool decorator or inheriting
+      from Tool/BaseTool
+    - **chain** - most Runnables are of this type

     Further, the events are categorized as one of:
-    * start - when the runnable starts
-    * stream - when the runnable is streaming
-    * end - when the runnable ends
+
+    - **start** - when the Runnable starts
+    - **stream** - when the Runnable is streaming
+    - **end** - when the Runnable ends

     start, stream and end are associated with slightly different `data` payload.

     Please see the documentation for `EventData` for more details.
     """
     run_id: str
-    """An randomly generated ID to keep track of the execution of the given runnable.
+    """A randomly generated ID to keep track of the execution of the given Runnable.

-    Each child runnable that gets invoked as part of the execution of a parent runnable
+    Each child Runnable that gets invoked as part of the execution of a parent Runnable
     is assigned its own unique ID.
     """
     tags: NotRequired[List[str]]
-    """Tags associated with the runnable that generated this event.
+    """Tags associated with the Runnable that generated this event.

-    Tags are always inherited from parent runnables.
+    Tags are always inherited from parent Runnables.

-    Tags can either be bound to a runnable using `.with_config({"tags": ["hello"]})`
+    Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
     or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
     """
     metadata: NotRequired[Dict[str, Any]]
-    """Metadata associated with the runnable that generated this event.
+    """Metadata associated with the Runnable that generated this event.

-    Metadata can either be bound to a runnable using
+    Metadata can either be bound to a Runnable using

         `.with_config({"metadata": { "foo": "bar" }})`

@@ -150,21 +153,20 @@ class StandardStreamEvent(BaseStreamEvent):
     The contents of the event data depend on the event type.
     """
     name: str
-    """The name of the runnable that generated the event."""
+    """The name of the Runnable that generated the event."""


 class CustomStreamEvent(BaseStreamEvent):
-    """A custom stream event created by the user.
+    """Custom stream event created by the user.

     .. versionadded:: 0.2.14
     """

     # Overwrite the event field to be more specific.
     event: Literal["on_custom_event"]  # type: ignore[misc]
-
     """The event type."""
     name: str
-    """A user defined name for the event."""
+    """User defined name for the event."""
     data: Any
     """The data associated with the event. Free form and can be anything."""
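A hedged sketch of consuming events shaped like the typedefs above via `astream_events` (the `version` flag matches the v2 API current at this release):

```python
async def show_events(chain, payload):
    # Each event carries: event name, Runnable name, run_id, tags, metadata.
    async for event in chain.astream_events(payload, version="v2"):
        print(event["event"], event["name"], event["run_id"])
```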
@@ -43,6 +43,7 @@ Output = TypeVar("Output", covariant=True)

 async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
     """Run a coroutine with a semaphore.
+
     Args:
         semaphore: The semaphore to use.
         coro: The coroutine to run.
@@ -59,7 +60,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:

     Args:
         n: The number of coroutines to run concurrently.
-        coros: The coroutines to run.
+        *coros: The coroutines to run.

     Returns:
         The results of the coroutines.
@@ -73,7 +74,14 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:


 def accepts_run_manager(callable: Callable[..., Any]) -> bool:
-    """Check if a callable accepts a run_manager argument."""
+    """Check if a callable accepts a run_manager argument.
+
+    Args:
+        callable: The callable to check.
+
+    Returns:
+        bool: True if the callable accepts a run_manager argument, False otherwise.
+    """
     try:
         return signature(callable).parameters.get("run_manager") is not None
     except ValueError:
@@ -81,7 +89,14 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:


 def accepts_config(callable: Callable[..., Any]) -> bool:
-    """Check if a callable accepts a config argument."""
+    """Check if a callable accepts a config argument.
+
+    Args:
+        callable: The callable to check.
+
+    Returns:
+        bool: True if the callable accepts a config argument, False otherwise.
+    """
     try:
         return signature(callable).parameters.get("config") is not None
     except ValueError:
@@ -89,7 +104,14 @@ def accepts_config(callable: Callable[..., Any]) -> bool:


 def accepts_context(callable: Callable[..., Any]) -> bool:
-    """Check if a callable accepts a context argument."""
+    """Check if a callable accepts a context argument.
+
+    Args:
+        callable: The callable to check.
+
+    Returns:
+        bool: True if the callable accepts a context argument, False otherwise.
+    """
     try:
         return signature(callable).parameters.get("context") is not None
     except ValueError:
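The three helpers above share one signature-probing idiom; restated standalone:

```python
from inspect import signature

def accepts_kwarg(callable_, name: str) -> bool:
    """True if `callable_` declares a parameter called `name`."""
    try:
        return signature(callable_).parameters.get(name) is not None
    except ValueError:  # some builtins have no introspectable signature
        return False

accepts_kwarg(lambda x, config: x, "config")  # -> True
accepts_kwarg(print, "config")                # -> False
```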
@@ -100,10 +122,24 @@ class IsLocalDict(ast.NodeVisitor):
     """Check if a name is a local dict."""

     def __init__(self, name: str, keys: Set[str]) -> None:
+        """Initialize the visitor.
+
+        Args:
+            name: The name to check.
+            keys: The keys to populate.
+        """
         self.name = name
         self.keys = keys

     def visit_Subscript(self, node: ast.Subscript) -> Any:
+        """Visit a subscript node.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         if (
             isinstance(node.ctx, ast.Load)
             and isinstance(node.value, ast.Name)
@@ -115,6 +151,14 @@ class IsLocalDict(ast.NodeVisitor):
             self.keys.add(node.slice.value)

     def visit_Call(self, node: ast.Call) -> Any:
+        """Visit a call node.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         if (
             isinstance(node.func, ast.Attribute)
             and isinstance(node.func.value, ast.Name)
@@ -135,18 +179,42 @@ class IsFunctionArgDict(ast.NodeVisitor):
         self.keys: Set[str] = set()

     def visit_Lambda(self, node: ast.Lambda) -> Any:
+        """Visit a lambda function.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         if not node.args.args:
             return
         input_arg_name = node.args.args[0].arg
         IsLocalDict(input_arg_name, self.keys).visit(node.body)

     def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
+        """Visit a function definition.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         if not node.args.args:
             return
         input_arg_name = node.args.args[0].arg
         IsLocalDict(input_arg_name, self.keys).visit(node)

     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
+        """Visit an async function definition.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         if not node.args.args:
             return
         input_arg_name = node.args.args[0].arg
@@ -161,12 +229,28 @@ class NonLocals(ast.NodeVisitor):
         self.stores: Set[str] = set()

     def visit_Name(self, node: ast.Name) -> Any:
+        """Visit a name node.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         if isinstance(node.ctx, ast.Load):
             self.loads.add(node.id)
         elif isinstance(node.ctx, ast.Store):
             self.stores.add(node.id)

     def visit_Attribute(self, node: ast.Attribute) -> Any:
+        """Visit an attribute node.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         if isinstance(node.ctx, ast.Load):
             parent = node.value
             attr_expr = node.attr
@@ -185,16 +269,40 @@ class FunctionNonLocals(ast.NodeVisitor):
         self.nonlocals: Set[str] = set()

     def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
+        """Visit a function definition.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         visitor = NonLocals()
         visitor.visit(node)
         self.nonlocals.update(visitor.loads - visitor.stores)

     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
+        """Visit an async function definition.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         visitor = NonLocals()
         visitor.visit(node)
         self.nonlocals.update(visitor.loads - visitor.stores)

     def visit_Lambda(self, node: ast.Lambda) -> Any:
+        """Visit a lambda function.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         visitor = NonLocals()
         visitor.visit(node)
         self.nonlocals.update(visitor.loads - visitor.stores)
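The visitors above power `get_function_first_arg_dict_keys`; a self-contained sketch of the same AST trick (run as a script, since `inspect.getsource` needs a file; assumes Python 3.9+ where `Subscript.slice` is the expression itself):

```python
import ast
import inspect
import textwrap

def first_arg_dict_keys(func):
    """Collect string keys read off a lambda's first argument."""
    tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
    keys = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Lambda) and node.args.args:
            arg = node.args.args[0].arg
            for sub in ast.walk(node.body):
                if (
                    isinstance(sub, ast.Subscript)
                    and isinstance(sub.value, ast.Name)
                    and sub.value.id == arg
                    and isinstance(sub.slice, ast.Constant)
                ):
                    keys.add(sub.slice.value)
    return sorted(keys)

fn = lambda x: x["name"] + str(x["age"])
print(first_arg_dict_keys(fn))  # -> ['age', 'name']
```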
@@ -209,14 +317,29 @@ class GetLambdaSource(ast.NodeVisitor):
         self.count = 0

     def visit_Lambda(self, node: ast.Lambda) -> Any:
-        """Visit a lambda function."""
+        """Visit a lambda function.
+
+        Args:
+            node: The node to visit.
+
+        Returns:
+            Any: The result of the visit.
+        """
         self.count += 1
         if hasattr(ast, "unparse"):
             self.source = ast.unparse(node)


 def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
-    """Get the keys of the first argument of a function if it is a dict."""
+    """Get the keys of the first argument of a function if it is a dict.
+
+    Args:
+        func: The function to check.
+
+    Returns:
+        Optional[List[str]]: The keys of the first argument if it is a dict,
+            None otherwise.
+    """
     try:
         code = inspect.getsource(func)
         tree = ast.parse(textwrap.dedent(code))
@@ -231,10 +354,10 @@ def get_lambda_source(func: Callable) -> Optional[str]:
     """Get the source code of a lambda function.

     Args:
-        func: a callable that can be a lambda function
+        func: a Callable that can be a lambda function.

     Returns:
-        str: the source code of the lambda function
+        str: the source code of the lambda function.
     """
     try:
         name = func.__name__ if func.__name__ != "<lambda>" else None
@@ -251,7 +374,14 @@ def get_lambda_source(func: Callable) -> Optional[str]:


 def get_function_nonlocals(func: Callable) -> List[Any]:
-    """Get the nonlocal variables accessed by a function."""
+    """Get the nonlocal variables accessed by a function.
+
+    Args:
+        func: The function to check.
+
+    Returns:
+        List[Any]: The nonlocal variables accessed by the function.
+    """
     try:
         code = inspect.getsource(func)
         tree = ast.parse(textwrap.dedent(code))
@@ -283,11 +413,11 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
     """Indent all lines of text after the first line.

     Args:
-        text: The text to indent
-        prefix: Used to determine the number of spaces to indent
+        text: The text to indent.
+        prefix: Used to determine the number of spaces to indent.

     Returns:
-        str: The indented text
+        str: The indented text.
     """
     n_spaces = len(prefix)
     spaces = " " * n_spaces
@@ -341,7 +471,14 @@ Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
|
||||
|
||||
|
||||
def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||
"""Add a sequence of addable objects together."""
|
||||
"""Add a sequence of addable objects together.
|
||||
|
||||
Args:
|
||||
addables: The addable objects to add.
|
||||
|
||||
Returns:
|
||||
Optional[Addable]: The result of adding the addable objects.
|
||||
"""
|
||||
final = None
|
||||
for chunk in addables:
|
||||
if final is None:
|
||||
@@ -352,7 +489,14 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||
|
||||
|
||||
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
"""Asynchronously add a sequence of addable objects together."""
|
||||
"""Asynchronously add a sequence of addable objects together.
|
||||
|
||||
Args:
|
||||
addables: The addable objects to add.
|
||||
|
||||
Returns:
|
||||
Optional[Addable]: The result of adding the addable objects.
|
||||
"""
|
||||
final = None
|
||||
async for chunk in addables:
|
||||
if final is None:
|
||||
@@ -363,7 +507,15 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
|
||||
|
||||
class ConfigurableField(NamedTuple):
|
||||
"""Field that can be configured by the user."""
|
||||
"""Field that can be configured by the user.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the field.
|
||||
name: The name of the field. Defaults to None.
|
||||
description: The description of the field. Defaults to None.
|
||||
annotation: The annotation of the field. Defaults to None.
|
||||
is_shared: Whether the field is shared. Defaults to False.
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
@@ -377,7 +529,16 @@ class ConfigurableField(NamedTuple):
|
||||
|
||||
|
||||
class ConfigurableFieldSingleOption(NamedTuple):
|
||||
"""Field that can be configured by the user with a default value."""
|
||||
"""Field that can be configured by the user with a default value.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the field.
|
||||
options: The options for the field.
|
||||
default: The default value for the field.
|
||||
name: The name of the field. Defaults to None.
|
||||
description: The description of the field. Defaults to None.
|
||||
is_shared: Whether the field is shared. Defaults to False.
|
||||
"""
|
||||
|
||||
id: str
|
||||
options: Mapping[str, Any]
|
||||
@@ -392,7 +553,16 @@ class ConfigurableFieldSingleOption(NamedTuple):
|
||||
|
||||
|
||||
class ConfigurableFieldMultiOption(NamedTuple):
|
||||
"""Field that can be configured by the user with multiple default values."""
|
||||
"""Field that can be configured by the user with multiple default values.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the field.
|
||||
options: The options for the field.
|
||||
default: The default values for the field.
|
||||
name: The name of the field. Defaults to None.
|
||||
description: The description of the field. Defaults to None.
|
||||
is_shared: Whether the field is shared. Defaults to False.
|
||||
"""
|
||||
|
||||
id: str
|
||||
options: Mapping[str, Any]
|
||||
@@ -412,7 +582,17 @@ AnyConfigurableField = Union[
|
||||
|
||||
|
||||
class ConfigurableFieldSpec(NamedTuple):
|
||||
"""Field that can be configured by the user. It is a specification of a field."""
|
||||
"""Field that can be configured by the user. It is a specification of a field.
|
||||
|
||||
Parameters:
|
||||
id: The unique identifier of the field.
|
||||
annotation: The annotation of the field.
|
||||
name: The name of the field. Defaults to None.
|
||||
description: The description of the field. Defaults to None.
|
||||
default: The default value for the field. Defaults to None.
|
||||
is_shared: Whether the field is shared. Defaults to False.
|
||||
dependencies: The dependencies of the field. Defaults to None.
|
||||
"""
|
||||
|
||||
id: str
|
||||
annotation: Any
|
||||
@@ -427,7 +607,17 @@ class ConfigurableFieldSpec(NamedTuple):
|
||||
def get_unique_config_specs(
|
||||
specs: Iterable[ConfigurableFieldSpec],
|
||||
) -> List[ConfigurableFieldSpec]:
|
||||
"""Get the unique config specs from a sequence of config specs."""
|
||||
"""Get the unique config specs from a sequence of config specs.
|
||||
|
||||
Args:
|
||||
specs: The config specs.
|
||||
|
||||
Returns:
|
||||
List[ConfigurableFieldSpec]: The unique config specs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the runnable sequence contains conflicting config specs.
|
||||
"""
|
||||
grouped = groupby(
|
||||
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
|
||||
)
|
||||
@@ -542,7 +732,15 @@ def _create_model_cached(
|
||||
def is_async_generator(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
||||
"""Check if a function is an async generator."""
|
||||
"""Check if a function is an async generator.
|
||||
|
||||
Args:
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
TypeGuard[Callable[..., AsyncIterator]: True if the function is
|
||||
an async generator, False otherwise.
|
||||
"""
|
||||
return (
|
||||
inspect.isasyncgenfunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
@@ -553,7 +751,15 @@ def is_async_generator(
|
||||
def is_async_callable(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., Awaitable]]:
|
||||
"""Check if a function is async."""
|
||||
"""Check if a function is async.
|
||||
|
||||
Args:
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
TypeGuard[Callable[..., Awaitable]: True if the function is async,
|
||||
False otherwise.
|
||||
"""
|
||||
return (
|
||||
asyncio.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
|
||||
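For orientation, a hedged sketch of two of the utilities whose docstrings change above (assuming `langchain-core` at this revision; the lambda and the printed values are illustrative):

```python
from langchain_core.runnables.utils import add, get_lambda_source

double = lambda x: x * 2  # noqa: E731 -- named so inspect.getsource can find it
print(get_lambda_source(double))  # "lambda x: x * 2" (needs ast.unparse, Python >= 3.9)
print(add([1, 2, 3]))  # 6 -- anything supporting "+" counts as Addable
```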
File diff suppressed because it is too large
@@ -62,7 +62,21 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
"""Start a trace for an LLM run.

Args:
serialized: The serialized model.
messages: The messages to start the chat with.
run_id: The run ID.
tags: The tags for the run. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata for the run. Defaults to None.
name: The name of the run.
**kwargs: Additional arguments.

Returns:
The run.
"""
chat_model_run = self._create_chat_model_run(
serialized=serialized,
messages=messages,
@@ -89,7 +103,21 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
"""Start a trace for an LLM run.

Args:
serialized: The serialized model.
prompts: The prompts to start the LLM with.
run_id: The run ID.
tags: The tags for the run. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata for the run. Defaults to None.
name: The name of the run.
**kwargs: Additional arguments.

Returns:
The run.
"""
llm_run = self._create_llm_run(
serialized=serialized,
prompts=prompts,
@@ -113,7 +141,18 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Run:
"""Run on new LLM token. Only available when streaming is enabled."""
"""Run on new LLM token. Only available when streaming is enabled.

Args:
token: The token.
chunk: The chunk. Defaults to None.
run_id: The run ID.
parent_run_id: The parent run ID. Defaults to None.
**kwargs: Additional arguments.

Returns:
The run.
"""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._llm_run_with_token_event(
@@ -133,6 +172,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Run on retry.

Args:
retry_state: The retry state.
run_id: The run ID.
**kwargs: Additional arguments.

Returns:
The run.
"""
llm_run = self._llm_run_with_retry_event(
retry_state=retry_state,
run_id=run_id,
@@ -140,7 +189,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
return llm_run

def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run."""
"""End a trace for an LLM run.

Args:
response: The response.
run_id: The run ID.
**kwargs: Additional arguments.

Returns:
The run.
"""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._complete_llm_run(
@@ -158,7 +216,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for an LLM run."""
"""Handle an error for an LLM run.

Args:
error: The error.
run_id: The run ID.
**kwargs: Additional arguments.

Returns:
The run.
"""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._errored_llm_run(
@@ -182,7 +249,22 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for a chain run."""
"""Start a trace for a chain run.

Args:
serialized: The serialized chain.
inputs: The inputs for the chain.
run_id: The run ID.
tags: The tags for the run. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata for the run. Defaults to None.
run_type: The type of the run. Defaults to None.
name: The name of the run.
**kwargs: Additional arguments.

Returns:
The run.
"""
chain_run = self._create_chain_run(
serialized=serialized,
inputs=inputs,
@@ -206,7 +288,17 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""End a trace for a chain run."""
"""End a trace for a chain run.

Args:
outputs: The outputs for the chain.
run_id: The run ID.
inputs: The inputs for the chain. Defaults to None.
**kwargs: Additional arguments.

Returns:
The run.
"""
chain_run = self._complete_chain_run(
outputs=outputs,
run_id=run_id,
@@ -225,7 +317,17 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for a chain run."""
"""Handle an error for a chain run.

Args:
error: The error.
inputs: The inputs for the chain. Defaults to None.
run_id: The run ID.
**kwargs: Additional arguments.

Returns:
The run.
"""
chain_run = self._errored_chain_run(
error=error,
run_id=run_id,
@@ -249,7 +351,22 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for a tool run."""
"""Start a trace for a tool run.

Args:
serialized: The serialized tool.
input_str: The input string.
run_id: The run ID.
tags: The tags for the run. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata for the run. Defaults to None.
name: The name of the run.
inputs: The inputs for the tool.
**kwargs: Additional arguments.

Returns:
The run.
"""
tool_run = self._create_tool_run(
serialized=serialized,
input_str=input_str,
@@ -266,7 +383,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
return tool_run

def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
"""End a trace for a tool run.

Args:
output: The output for the tool.
run_id: The run ID.
**kwargs: Additional arguments.

Returns:
The run.
"""
tool_run = self._complete_tool_run(
output=output,
run_id=run_id,
@@ -283,7 +409,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for a tool run."""
"""Handle an error for a tool run.

Args:
error: The error.
run_id: The run ID.
**kwargs: Additional arguments.

Returns:
The run.
"""
tool_run = self._errored_tool_run(
error=error,
run_id=run_id,
@@ -304,7 +439,21 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Run when Retriever starts running."""
"""Run when the Retriever starts running.

Args:
serialized: The serialized retriever.
query: The query.
run_id: The run ID.
parent_run_id: The parent run ID. Defaults to None.
tags: The tags for the run. Defaults to None.
metadata: The metadata for the run. Defaults to None.
name: The name of the run.
**kwargs: Additional arguments.

Returns:
The run.
"""
retrieval_run = self._create_retrieval_run(
serialized=serialized,
query=query,
@@ -326,7 +475,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Run when Retriever errors."""
"""Run when Retriever errors.

Args:
error: The error.
run_id: The run ID.
**kwargs: Additional arguments.

Returns:
The run.
"""
retrieval_run = self._errored_retrieval_run(
error=error,
run_id=run_id,
@@ -339,7 +497,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> Run:
"""Run when Retriever ends running."""
"""Run when the Retriever ends running.

Args:
documents: The documents.
run_id: The run ID.
**kwargs: Additional arguments.

Returns:
The run.
"""
retrieval_run = self._complete_retrieval_run(
documents=documents,
run_id=run_id,
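Since every `on_*` hook above returns a `Run`, a concrete tracer only has to decide what to do with finished runs. A minimal sketch (the class name and print format are illustrative):

```python
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run


class PrintTracer(BaseTracer):
    """Toy tracer: the on_* callbacks assemble Run objects; a subclass
    only needs to implement _persist_run for each finished root run."""

    def _persist_run(self, run: Run) -> None:
        print(run.run_type, run.name, run.error)
```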
@@ -68,8 +68,8 @@ def tracing_v2_enabled(
client (LangSmithClient, optional): The client of the langsmith.
Defaults to None.

Returns:
None
Yields:
LangChainTracer: The LangChain tracer.

Example:
>>> with tracing_v2_enabled():
@@ -100,7 +100,7 @@
def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
"""Collect all run traces in context.

Returns:
Yields:
run_collector.RunCollectorCallbackHandler: The run collector callback handler.

Example:
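A short usage sketch for `collect_runs` (unlike `tracing_v2_enabled`, it needs no LangSmith credentials):

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.context import collect_runs

with collect_runs() as cb:
    RunnableLambda(lambda x: x + 1).invoke(1)

print(len(cb.traced_runs))  # 1 -- the root run was collected
```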
@@ -46,7 +46,8 @@ SCHEMA_FORMAT_TYPE = Literal["original", "streaming_events"]

class _TracerCore(ABC):
"""
Abstract base class for tracers
Abstract base class for tracers.

This class provides common methods, and reusable methods for tracers.
"""

@@ -65,17 +66,18 @@ class _TracerCore(ABC):
Args:
_schema_format: Primarily changes how the inputs and outputs are
handled. For internal use only. This API will change.

- 'original' is the format used by all current tracers.
This format is slightly inconsistent with respect to inputs
and outputs.
This format is slightly inconsistent with respect to inputs
and outputs.
- 'streaming_events' is used for supporting streaming events,
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.
- 'original+chat' is a format that is the same as 'original'
except it does NOT raise an attribute error on_chat_model_start
except it does NOT raise an attribute error on_chat_model_start
kwargs: Additional keyword arguments that will be passed to
the super class.
the superclass.
"""
super().__init__(**kwargs)
self._schema_format = _schema_format # For internal use only API will change.
@@ -207,7 +209,7 @@ class _TracerCore(ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Create a llm run"""
"""Create a llm run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
@@ -234,7 +236,7 @@ class _TracerCore(ABC):
**kwargs: Any,
) -> Run:
"""
Append token event to LLM run and return the run
Append token event to LLM run and return the run.
"""
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
event_kwargs: Dict[str, Any] = {"token": token}
@@ -314,7 +316,7 @@ class _TracerCore(ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Create a chain Run"""
"""Create a chain Run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
@@ -104,7 +104,7 @@ class EvaluatorCallbackHandler(BaseTracer):
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
"""Evaluate the run in the project.

Parameters
Args:
----------
run : Run
The run to be evaluated.
@@ -200,7 +200,7 @@ class EvaluatorCallbackHandler(BaseTracer):
def _persist_run(self, run: Run) -> None:
"""Run the evaluator on the run.

Parameters
Args:
----------
run : Run
The run to be evaluated.
@@ -52,7 +52,18 @@ logger = logging.getLogger(__name__)


class RunInfo(TypedDict):
"""Information about a run."""
"""Information about a run.

This is used to keep track of the metadata associated with a run.

Parameters:
name: The name of the run.
tags: The tags associated with the run.
metadata: The metadata associated with the run.
run_type: The type of the run.
inputs: The inputs to the run.
parent_run_id: The ID of the parent run.
"""

name: str
tags: List[str]
@@ -150,7 +161,19 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
"""Tap the output aiter."""
"""Tap the output aiter.

This method is used to tap the output of a Runnable that produces
an async iterator. It is used to generate stream events for the
output of the Runnable.

Args:
run_id: The ID of the run.
output: The output of the Runnable.

Yields:
T: The output of the Runnable.
"""
sentinel = object()
# atomic check and set
tap = self.is_tapped.setdefault(run_id, sentinel)
@@ -192,7 +215,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
yield chunk

def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
"""Tap the output aiter."""
"""Tap the output aiter.

Args:
run_id: The ID of the run.
output: The output of the Runnable.

Yields:
T: The output of the Runnable.
"""
sentinel = object()
# atomic check and set
tap = self.is_tapped.setdefault(run_id, sentinel)
@@ -32,7 +32,12 @@ _EXECUTOR: Optional[ThreadPoolExecutor] = None


def log_error_once(method: str, exception: Exception) -> None:
"""Log an error once."""
"""Log an error once.

Args:
method: The method that raised the exception.
exception: The exception that was raised.
"""
global _LOGGED
if (method, type(exception)) in _LOGGED:
return
@@ -82,7 +87,15 @@ class LangChainTracer(BaseTracer):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Initialize the LangChain tracer."""
"""Initialize the LangChain tracer.

Args:
example_id: The example ID.
project_name: The project name. Defaults to the tracer project.
client: The client. Defaults to the global client.
tags: The tags. Defaults to an empty list.
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
@@ -104,7 +117,21 @@ class LangChainTracer(BaseTracer):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
"""Start a trace for an LLM run.

Args:
serialized: The serialized model.
messages: The messages.
run_id: The run ID.
tags: The tags. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata. Defaults to None.
name: The name. Defaults to None.
**kwargs: Additional keyword arguments.

Returns:
Run: The run.
"""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
@@ -130,7 +157,15 @@ class LangChainTracer(BaseTracer):
self.latest_run = run_

def get_run_url(self) -> str:
"""Get the LangSmith root run URL"""
"""Get the LangSmith root run URL.

Returns:
str: The LangSmith root run URL.

Raises:
ValueError: If no traced run is found.
ValueError: If the run URL cannot be found.
"""
if not self.latest_run:
raise ValueError("No traced run found.")
# If this is the first run in a project, the project may not yet be created.
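A hedged sketch of how this tracer is typically wired in (assumes LangSmith credentials such as `LANGCHAIN_API_KEY` are configured; the project name is illustrative):

```python
from langchain_core.tracers.langchain import LangChainTracer

tracer = LangChainTracer(project_name="my-project")  # "my-project" is a placeholder
# some_runnable.invoke(inputs, config={"callbacks": [tracer]})
# print(tracer.get_run_url())  # raises ValueError until at least one run is traced
```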
@@ -189,12 +189,15 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
handled.
**For internal use only. This API will change.**
- 'original' is the format used by all current tracers.
This format is slightly inconsistent with respect to inputs
and outputs.
This format is slightly inconsistent with respect to inputs
and outputs.
- 'streaming_events' is used for supporting streaming events,
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.

Raises:
ValueError: If an invalid schema format is provided (internal use only).
"""
if _schema_format not in {"original", "streaming_events"}:
raise ValueError(
@@ -224,7 +227,15 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
return self.receive_stream.__aiter__()

def send(self, *ops: Dict[str, Any]) -> bool:
"""Send a patch to the stream, return False if the stream is closed."""
"""Send a patch to the stream, return False if the stream is closed.

Args:
*ops: The operations to send to the stream.

Returns:
bool: True if the patch was sent successfully, False if the stream
is closed.
"""
# We will likely want to wrap this in try / except at some point
# to handle exceptions that might arise at run time.
# For now we'll let the exception bubble up, and always return
@@ -235,7 +246,15 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
async def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
"""Tap an output async iterator to stream its values to the log."""
"""Tap an output async iterator to stream its values to the log.

Args:
run_id: The ID of the run.
output: The output async iterator.

Yields:
T: The output value.
"""
async for chunk in output:
# root run is handled in .astream_log()
if run_id != self.root_id:
@@ -254,7 +273,15 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
yield chunk

def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
"""Tap an output async iterator to stream its values to the log."""
"""Tap an output async iterator to stream its values to the log.

Args:
run_id: The ID of the run.
output: The output iterator.

Yields:
T: The output value.
"""
for chunk in output:
# root run is handled in .astream_log()
if run_id != self.root_id:
@@ -273,6 +300,14 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
yield chunk

def include_run(self, run: Run) -> bool:
"""Check if a Run should be included in the log.

Args:
run: The Run to check.

Returns:
bool: True if the run should be included, False otherwise.
"""
if run.id == self.root_id:
return False

@@ -454,7 +489,7 @@ def _get_standardized_inputs(
Returns:
Valid inputs are only dict. By conventions, inputs always represented
invocation using named arguments.
A None means that the input is not yet known!
None means that the input is not yet known!
"""
if schema_format == "original":
raise NotImplementedError(
@@ -33,11 +33,27 @@ class _SendStream(Generic[T]):
self._done = done

async def send(self, item: T) -> None:
"""Schedule the item to be written to the queue using the original loop."""
"""Schedule the item to be written to the queue using the original loop.

This is a coroutine that can be awaited.

Args:
item: The item to write to the queue.
"""
return self.send_nowait(item)

def send_nowait(self, item: T) -> None:
"""Schedule the item to be written to the queue using the original loop."""
"""Schedule the item to be written to the queue using the original loop.

This is a non-blocking call.

Args:
item: The item to write to the queue.

Raises:
RuntimeError: If the event loop is already closed when trying to write
to the queue.
"""
try:
self._reader_loop.call_soon_threadsafe(self._queue.put_nowait, item)
except RuntimeError:
@@ -45,11 +61,18 @@ class _SendStream(Generic[T]):
raise # Raise the exception if the loop is not closed

async def aclose(self) -> None:
"""Schedule the done object write the queue using the original loop."""
"""Async schedule the done object write the queue using the original loop."""
return self.close()

def close(self) -> None:
"""Schedule the done object write the queue using the original loop."""
"""Schedule the done object write the queue using the original loop.

This is a non-blocking call.

Raises:
RuntimeError: If the event loop is already closed when trying to write
to the queue.
"""
try:
self._reader_loop.call_soon_threadsafe(self._queue.put_nowait, self._done)
except RuntimeError:
@@ -87,7 +110,7 @@ class _MemoryStream(Generic[T]):

This implementation is meant to be used with a single writer and a single reader.

This is an internal implementation to LangChain please do not use it directly.
This is an internal implementation to LangChain. Please do not use it directly.
"""

def __init__(self, loop: AbstractEventLoop) -> None:
@@ -103,11 +126,19 @@ class _MemoryStream(Generic[T]):
self._done = object()

def get_send_stream(self) -> _SendStream[T]:
"""Get a writer for the channel."""
"""Get a writer for the channel.

Returns:
_SendStream: The writer for the channel.
"""
return _SendStream[T](
reader_loop=self._loop, queue=self._queue, done=self._done
)

def get_receive_stream(self) -> _ReceiveStream[T]:
"""Get a reader for the channel."""
"""Get a reader for the channel.

Returns:
_ReceiveStream: The reader for the channel.
"""
return _ReceiveStream[T](queue=self._queue, done=self._done)
@@ -16,7 +16,16 @@ AsyncListener = Union[


class RootListenersTracer(BaseTracer):
"""Tracer that calls listeners on run start, end, and error."""
"""Tracer that calls listeners on run start, end, and error.

Parameters:
log_missing_parent: Whether to log a warning if the parent is missing.
Default is False.
config: The runnable config.
on_start: The listener to call on run start.
on_end: The listener to call on run end.
on_error: The listener to call on run error.
"""

log_missing_parent = False

@@ -28,6 +37,14 @@ class RootListenersTracer(BaseTracer):
on_end: Optional[Listener],
on_error: Optional[Listener],
) -> None:
"""Initialize the tracer.

Args:
config: The runnable config.
on_start: The listener to call on run start.
on_end: The listener to call on run end.
on_error: The listener to call on run error
"""
super().__init__(_schema_format="original+chat")

self.config = config
@@ -63,7 +80,16 @@ class RootListenersTracer(BaseTracer):


class AsyncRootListenersTracer(AsyncBaseTracer):
"""Async Tracer that calls listeners on run start, end, and error."""
"""Async Tracer that calls listeners on run start, end, and error.

Parameters:
log_missing_parent: Whether to log a warning if the parent is missing.
Default is False.
config: The runnable config.
on_start: The listener to call on run start.
on_end: The listener to call on run end.
on_error: The listener to call on run error.
"""

log_missing_parent = False

@@ -75,6 +101,14 @@ class AsyncRootListenersTracer(AsyncBaseTracer):
on_end: Optional[AsyncListener],
on_error: Optional[AsyncListener],
) -> None:
"""Initialize the tracer.

Args:
config: The runnable config.
on_start: The listener to call on run start.
on_end: The listener to call on run end.
on_error: The listener to call on run error
"""
super().__init__(_schema_format="original+chat")

self.config = config
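These tracers back `Runnable.with_listeners()`, which is the usual way to reach them. A usage sketch (the lambdas are illustrative):

```python
from langchain_core.runnables import RunnableLambda

chain = RunnableLambda(lambda x: x + 1).with_listeners(
    on_start=lambda run: print("start:", run.name),
    on_end=lambda run: print("end:", run.outputs),
)
chain.invoke(1)  # prints start/end around the call, returns 2
```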
@@ -8,13 +8,13 @@ from langchain_core.tracers.schemas import Run


class RunCollectorCallbackHandler(BaseTracer):
"""
Tracer that collects all nested runs in a list.
"""Tracer that collects all nested runs in a list.

This tracer is useful for inspection and evaluation purposes.

Parameters
----------
name : str, default="run-collector_callback_handler"
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
"""
@@ -31,6 +31,8 @@ class RunCollectorCallbackHandler(BaseTracer):
----------
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
**kwargs : Any
Additional keyword arguments
"""
super().__init__(**kwargs)
self.example_id = (
@@ -112,7 +112,15 @@ class ToolRun(BaseRun):


class Run(BaseRunV2):
"""Run schema for the V2 API in the Tracer."""
"""Run schema for the V2 API in the Tracer.

Parameters:
child_runs: The child runs.
tags: The tags. Default is an empty list.
events: The events. Default is an empty list.
trace_id: The trace ID. Default is None.
dotted_order: The dotted order.
"""

child_runs: List[Run] = Field(default_factory=list)
tags: Optional[List[str]] = Field(default_factory=list)
@@ -7,15 +7,14 @@ from langchain_core.utils.input import get_bolded_text, get_colored_text


def try_json_stringify(obj: Any, fallback: str) -> str:
"""
Try to stringify an object to JSON.
"""Try to stringify an object to JSON.

Args:
obj: Object to stringify.
fallback: Fallback string to return if the object cannot be stringified.

Returns:
A JSON string if the object can be stringified, otherwise the fallback string.

"""
try:
return json.dumps(obj, indent=2, ensure_ascii=False)
@@ -45,6 +44,8 @@ class FunctionCallbackHandler(BaseTracer):
"""Tracer that calls a function with a single str parameter."""

name: str = "function_callback_handler"
"""The name of the tracer. This is used to identify the tracer in the logs.
Default is "function_callback_handler"."""

def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None:
super().__init__(**kwargs)
@@ -54,6 +55,14 @@ class FunctionCallbackHandler(BaseTracer):
pass

def get_parents(self, run: Run) -> List[Run]:
"""Get the parents of a run.

Args:
run: The run to get the parents of.

Returns:
A list of parent runs.
"""
parents = []
current_run = run
while current_run.parent_run_id:
@@ -66,6 +75,14 @@ class FunctionCallbackHandler(BaseTracer):
return parents

def get_breadcrumbs(self, run: Run) -> str:
"""Get the breadcrumbs of a run.

Args:
run: The run to get the breadcrumbs of.

Returns:
A string with the breadcrumbs of the run.
"""
parents = self.get_parents(run)[::-1]
string = " > ".join(
f"{parent.run_type}:{parent.name}"
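A sketch of how `FunctionCallbackHandler` is typically consumed: it forwards formatted trace text to any `str -> None` callable (here, a list's `append`):

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.stdout import FunctionCallbackHandler

lines: list = []
handler = FunctionCallbackHandler(function=lines.append)
RunnableLambda(lambda x: x * 2).invoke(3, config={"callbacks": [handler]})
print("".join(lines))  # "[chain/start] ... [chain/end] ..." trace text
```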
@@ -8,6 +8,17 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.

Args:
left: The first dictionary to merge.
others: The other dictionaries to merge.

Returns:
The merged dictionary.

Raises:
TypeError: If the key exists in both dictionaries but has a different type.
TypeError: If the value has an unsupported type.

Example:
If left = {"function_call": {"arguments": None}} and
right = {"function_call": {"arguments": "{\n"}}
@@ -46,7 +57,15 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]


def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
"""Add many lists, handling None."""
"""Add many lists, handling None.

Args:
left: The first list to merge.
others: The other lists to merge.

Returns:
The merged list.
"""
merged = left.copy() if left is not None else None
for other in others:
if other is None:
@@ -75,6 +94,23 @@ def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]


def merge_obj(left: Any, right: Any) -> Any:
"""Merge two objects.

It handles specific scenarios where a key exists in both
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.

Args:
left: The first object to merge.
right: The other object to merge.

Returns:
The merged object.

Raises:
TypeError: If the key exists in both dictionaries but has a different type.
ValueError: If the two objects cannot be merged.
"""
if left is None or right is None:
return left if left is not None else right
elif type(left) is not type(right):
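A worked example of the merge semantics described above, mirroring the docstring's own case (`merge_dicts` is a private helper, so the import path may change between releases):

```python
from langchain_core.utils._merge import merge_dicts

left = {"function_call": {"arguments": None}}
right = {"function_call": {"arguments": "{\n"}}
print(merge_dicts(left, right))  # {'function_call': {'arguments': '{\n'}}
```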
@@ -44,6 +44,18 @@ def py_anext(
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.

Args:
iterator: The async iterator to advance.
default: The value to return if the iterator is exhausted.
If not provided, a StopAsyncIteration exception is raised.

Returns:
The next value from the iterator, or the default value
if the iterator is exhausted.

Raises:
TypeError: If the iterator is not an async iterator.
"""

try:
@@ -71,7 +83,7 @@ def py_anext(


class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
"""Dummy lock that provides the proper interface but no protection."""

async def __aenter__(self) -> None:
pass
@@ -88,7 +100,21 @@ async def tee_peer(
peers: List[Deque[T]],
lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
"""An individual iterator of a :py:func:`~.tee`.

This function is a generator that yields items from the shared iterator
``iterator``. It buffers items until the least advanced iterator has
yielded them as well. The buffer is shared with all other peers.

Args:
iterator: The shared iterator.
buffer: The buffer for this peer.
peers: The buffers of all peers.
lock: The lock to synchronise access to the shared buffers.

Yields:
The next item from the shared iterator.
"""
try:
while True:
if not buffer:
@@ -204,6 +230,7 @@ class Tee(Generic[T]):
return False

async def aclose(self) -> None:
"""Async close all child iterators."""
for child in self._children:
await child.aclose()

@@ -258,7 +285,7 @@ async def abatch_iterate(
iterable: The async iterable to batch.

Returns:
An async iterator over the batches
An async iterator over the batches.
"""
batch: List[T] = []
async for element in iterable:
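A minimal sketch of the async batching helper touched above (the generator and batch size are illustrative):

```python
import asyncio

from langchain_core.utils.aiter import abatch_iterate


async def numbers():
    for i in range(5):
        yield i


async def main() -> None:
    async for batch in abatch_iterate(2, numbers()):
        print(batch)  # [0, 1] then [2, 3] then [4]


asyncio.run(main())
```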
@@ -1,42 +0,0 @@
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable


def curry(func: Callable[..., Any], **curried_kwargs: Any) -> Callable[..., Any]:
"""Util that wraps a function and partially applies kwargs to it.
Returns a new function whose signature omits the curried variables.

Args:
func: The function to curry.
curried_kwargs: Arguments to apply to the function.

Returns:
A new function with curried arguments applied.

.. versionadded:: 0.2.18
"""

@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwargs = {**curried_kwargs, **kwargs}
return await func(*args, **new_kwargs)

@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwargs = {**curried_kwargs, **kwargs}
return func(*args, **new_kwargs)

sig = inspect.signature(func)
# Create a new signature without the curried parameters
new_params = [p for name, p in sig.parameters.items() if name not in curried_kwargs]

if asyncio.iscoroutinefunction(func):
async_wrapper = wraps(func)(async_wrapper)
setattr(async_wrapper, "__signature__", sig.replace(parameters=new_params))
return async_wrapper
else:
sync_wrapper = wraps(func)(sync_wrapper)
setattr(sync_wrapper, "__signature__", sig.replace(parameters=new_params))
return sync_wrapper
@@ -36,7 +36,7 @@ def get_from_dict_or_env(
env_key: The environment variable to look up if the key is not
in the dictionary.
default: The default value to return if the key is not in the dictionary
or the environment.
or the environment. Defaults to None.
"""
if isinstance(key, (list, tuple)):
for k in key:
@@ -56,7 +56,22 @@ def get_from_dict_or_env(


def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
"""Get a value from a dictionary or an environment variable."""
"""Get a value from a dictionary or an environment variable.

Args:
key: The key to look up in the dictionary.
env_key: The environment variable to look up if the key is not
in the dictionary.
default: The default value to return if the key is not in the dictionary
or the environment. Defaults to None.

Returns:
str: The value of the key.

Raises:
ValueError: If the key is not in the dictionary and no default value is
provided or if the environment variable is not set.
"""
if env_key in os.environ and os.environ[env_key]:
return os.environ[env_key]
elif default is not None:
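A usage sketch for these env helpers (the variable name and value are hypothetical, set here only for illustration):

```python
import os

from langchain_core.utils.env import get_from_dict_or_env

os.environ["MY_SERVICE_KEY"] = "secret"  # hypothetical variable for this demo
value = get_from_dict_or_env({}, "my_service_key", "MY_SERVICE_KEY")
print(value)  # "secret" -- the env var backstops the missing dict key
```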
@@ -10,7 +10,19 @@ class StrictFormatter(Formatter):
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
"""Check that no arguments are provided.

Args:
format_string: The format string.
args: The arguments.
kwargs: The keyword arguments.

Returns:
The formatted string.

Raises:
ValueError: If any arguments are provided.
"""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
@@ -21,6 +33,15 @@ class StrictFormatter(Formatter):
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
"""Check that all input variables are used in the format string.

Args:
format_string: The format string.
input_variables: The input variables.

Raises:
ValueError: If any input variables are not used in the format string.
"""
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
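A short sketch of the strict-formatting behaviour documented above, via the module-level `formatter` instance:

```python
from langchain_core.utils.formatting import formatter  # module-level StrictFormatter()

print(formatter.format("Hello {name}!", name="world"))  # Hello world!
formatter.validate_input_variables("Hello {name}!", ["name"])  # passes silently
# formatter.format("Hi {x}", "positional")  # would raise ValueError: no positional args
```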
@@ -55,7 +55,9 @@ class ToolDescription(TypedDict):
"""Representation of a callable function to the OpenAI API."""

type: Literal["function"]
"""The type of the tool."""
function: FunctionDescription
"""The function description."""


def _rm_titles(kv: dict, prev_key: str = "") -> dict:
@@ -85,7 +87,19 @@ def convert_pydantic_to_openai_function(
description: Optional[str] = None,
rm_titles: bool = True,
) -> FunctionDescription:
"""Converts a Pydantic model to a function description for the OpenAI API."""
"""Converts a Pydantic model to a function description for the OpenAI API.

Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
rm_titles: Whether to remove titles from the schema. Defaults to True.

Returns:
The function description.
"""
schema = dereference_refs(model.schema())
schema.pop("definitions", None)
title = schema.pop("title", "")
@@ -108,7 +122,18 @@ def convert_pydantic_to_openai_tool(
name: Optional[str] = None,
description: Optional[str] = None,
) -> ToolDescription:
"""Converts a Pydantic model to a function description for the OpenAI API."""
"""Converts a Pydantic model to a function description for the OpenAI API.

Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.

Returns:
The tool description.
"""
function = convert_pydantic_to_openai_function(
model, name=name, description=description
)
@@ -133,12 +158,22 @@ def convert_python_function_to_openai_function(
Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be
included as well.

Args:
function: The Python function to convert.

Returns:
The OpenAI function description.
"""
from langchain_core import tools

func_name = _get_python_function_name(function)
model = tools.create_schema_from_function(
func_name, function, filter_args=(), parse_docstring=True
func_name,
function,
filter_args=(),
parse_docstring=True,
error_on_invalid_docstring=False,
)
return convert_pydantic_to_openai_function(
model,
@@ -153,7 +188,14 @@ def convert_python_function_to_openai_function(
removal="0.3.0",
)
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tool into the OpenAI function API."""
"""Format tool into the OpenAI function API.

Args:
tool: The tool to format.

Returns:
The function description.
"""
if tool.args_schema:
return convert_pydantic_to_openai_function(
tool.args_schema, name=tool.name, description=tool.description
@@ -183,7 +225,14 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
removal="0.3.0",
)
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
"""Format tool into the OpenAI function API."""
"""Format tool into the OpenAI function API.

Args:
tool: The tool to format.

Returns:
The tool description.
"""
function = format_tool_to_openai_function(tool)
return {"type": "function", "function": function}

@@ -202,6 +251,9 @@ def convert_to_openai_function(
Returns:
A dict version of the passed in function which is compatible with the
OpenAI function-calling API.

Raises:
ValueError: If the function is not in a supported format.
"""
from langchain_core.tools import BaseTool

@@ -280,7 +332,7 @@ def tool_example_to_messages(
BaseModels
tool_outputs: Optional[List[str]], a list of tool call outputs.
Does not need to be provided. If not provided, a placeholder value
will be inserted.
will be inserted. Defaults to None.

Returns:
A list of messages
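A hedged sketch of the conversion utilities covered above (the `Multiply` model is illustrative; this era of langchain-core still ships the `pydantic_v1` shim):

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_function


class Multiply(BaseModel):
    """Multiply two integers."""

    a: int = Field(..., description="First factor")
    b: int = Field(..., description="Second factor")


print(convert_to_openai_function(Multiply))
# {'name': 'Multiply', 'description': 'Multiply two integers.', 'parameters': {...}}
```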
@@ -34,11 +34,11 @@ DEFAULT_LINK_REGEX = (
def find_all_links(
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]:
"""Extract all links from a raw html string.
"""Extract all links from a raw HTML string.

Args:
raw_html: original html.
pattern: Regex to use for extracting links from raw html.
raw_html: original HTML.
pattern: Regex to use for extracting links from raw HTML.

Returns:
List[str]: all links
@@ -57,20 +57,20 @@ def extract_sub_links(
exclude_prefixes: Sequence[str] = (),
continue_on_failure: bool = False,
) -> List[str]:
"""Extract all links from a raw html string and convert into absolute paths.
"""Extract all links from a raw HTML string and convert into absolute paths.

Args:
raw_html: original html.
url: the url of the html.
base_url: the base url to check for outside links against.
pattern: Regex to use for extracting links from raw html.
raw_html: original HTML.
url: the url of the HTML.
base_url: the base URL to check for outside links against.
pattern: Regex to use for extracting links from raw HTML.
prevent_outside: If True, ignore external links which are not children
of the base url.
of the base URL.
exclude_prefixes: Exclude any URLs that start with one of these prefixes.
continue_on_failure: If True, continue if parsing a specific link raises an
exception. Otherwise, raise the exception.
Returns:
List[str]: sub links
List[str]: sub links.
"""
base_url_to_use = base_url if base_url is not None else url
parsed_base_url = urlparse(base_url_to_use)
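A small sketch of the two link helpers (the HTML snippet and URLs are illustrative):

```python
from langchain_core.utils.html import extract_sub_links, find_all_links

html = '<a href="/docs">Docs</a> <a href="https://other.example/x">Out</a>'
print(find_all_links(html))  # ['/docs', 'https://other.example/x']
print(extract_sub_links(html, "https://example.com/", prevent_outside=True))
# ['https://example.com/docs'] -- the external link is filtered out
```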
@@ -3,12 +3,27 @@ import mimetypes


def encode_image(image_path: str) -> str:
"""Get base64 string from image URI."""
"""Get base64 string from image URI.

Args:
image_path: The path to the image.

Returns:
The base64 string of the image.
"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")


def image_to_data_url(image_path: str) -> str:
"""Get data URL from image URI.

Args:
image_path: The path to the image.

Returns:
The data URL of the image.
"""
encoding = encode_image(image_path)
mime_type = mimetypes.guess_type(image_path)[0]
return f"data:{mime_type};base64,{encoding}"
@@ -14,7 +14,15 @@ _TEXT_COLOR_MAPPING = {
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
"""Get mapping for items to a support color.

Args:
items: The items to map to colors.
excluded_colors: The colors to exclude.

Returns:
The mapping of items to colors.
"""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
@@ -23,20 +31,45 @@ def get_color_mapping(


def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
"""Get colored text.

Args:
text: The text to color.
color: The color to use.

Returns:
The colored text.
"""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"


def get_bolded_text(text: str) -> str:
"""Get bolded text."""
"""Get bolded text.

Args:
text: The text to bold.

Returns:
The bolded text.
"""
return f"\033[1m{text}\033[0m"


def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
"""Print text with highlighting and no end characters.

If a color is provided, the text will be printed in that color.
If a file is provided, the text will be written to that file.

Args:
text: The text to print.
color: The color to use. Defaults to None.
end: The end character to use. Defaults to "".
file: The file to write to. Defaults to None.
"""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
@@ -22,7 +22,7 @@ T = TypeVar("T")


 class NoLock:
-    """Dummy lock that provides the proper interface but no protection"""
+    """Dummy lock that provides the proper interface but no protection."""

     def __enter__(self) -> None:
         pass

@@ -39,7 +39,21 @@ def tee_peer(
     peers: List[Deque[T]],
     lock: ContextManager[Any],
 ) -> Generator[T, None, None]:
-    """An individual iterator of a :py:func:`~.tee`"""
+    """An individual iterator of a :py:func:`~.tee`.
+
+    This function is a generator that yields items from the shared iterator
+    ``iterator``. It buffers items until the least advanced iterator has
+    yielded them as well. The buffer is shared with all other peers.
+
+    Args:
+        iterator: The shared iterator.
+        buffer: The buffer for this peer.
+        peers: The buffers of all peers.
+        lock: The lock to synchronise access to the shared buffers.
+
+    Yields:
+        The next item from the shared iterator.
+    """
     try:
         while True:
             if not buffer:

@@ -118,6 +132,14 @@ class Tee(Generic[T]):
         *,
         lock: Optional[ContextManager[Any]] = None,
     ):
+        """Create a new ``tee``.
+
+        Args:
+            iterable: The iterable to split.
+            n: The number of iterators to create. Defaults to 2.
+            lock: The lock to synchronise access to the shared buffers.
+                Defaults to None.
+        """
         self._iterator = iter(iterable)
         self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
         self._children = tuple(

@@ -170,8 +192,8 @@ def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T
         size: The size of the batch. If None, returns a single batch.
         iterable: The iterable to batch.

-    Returns:
-        An iterator over the batches.
+    Yields:
+        The batches of the iterable.
     """
     it = iter(iterable)
     while True:
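# A minimal sketch of batch_iterate, matching the Returns -> Yields change
# above: batches are produced lazily, one at a time.
from langchain_core.utils.iter import batch_iterate

for batch in batch_iterate(2, ["a", "b", "c", "d", "e"]):
    print(batch)  # ['a', 'b'], then ['c', 'd'], then ['e']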
@@ -124,8 +124,7 @@ _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
 def parse_json_markdown(
     json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
 ) -> dict:
-    """
-    Parse a JSON string from a Markdown string.
+    """Parse a JSON string from a Markdown string.

     Args:
         json_string: The Markdown string.

@@ -175,6 +174,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:

     Returns:
         The parsed JSON object as a Python dictionary.
+
+    Raises:
+        OutputParserException: If the JSON string is invalid or does not contain
+            the expected keys.
     """
     try:
         json_obj = parse_json_markdown(text)
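# A minimal sketch of parse_json_markdown on a fenced JSON block:
from langchain_core.utils.json import parse_json_markdown

text = '```json\n{"answer": 42}\n```'
assert parse_json_markdown(text) == {"answer": 42}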
@@ -90,7 +90,16 @@ def dereference_refs(
     full_schema: Optional[dict] = None,
     skip_keys: Optional[Sequence[str]] = None,
 ) -> dict:
-    """Try to substitute $refs in JSON Schema."""
+    """Try to substitute $refs in JSON Schema.
+
+    Args:
+        schema_obj: The schema object to dereference.
+        full_schema: The full schema object. Defaults to None.
+        skip_keys: The keys to skip. Defaults to None.
+
+    Returns:
+        The dereferenced schema object.
+    """
     full_schema = full_schema or schema_obj
     skip_keys = (
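# A minimal sketch of dereference_refs; the schema below is illustrative.
from langchain_core.utils.json_schema import dereference_refs

schema = {
    "type": "object",
    "properties": {"pet": {"$ref": "#/$defs/Pet"}},
    "$defs": {"Pet": {"type": "string"}},
}
resolved = dereference_refs(schema)  # the $ref is replaced by the Pet definition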
@@ -42,7 +42,15 @@ class ChevronError(SyntaxError):


 def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
-    """Parse a literal from the template."""
+    """Parse a literal from the template.
+
+    Args:
+        template: The template to parse.
+        l_del: The left delimiter.
+
+    Returns:
+        Tuple[str, str]: The literal and the template.
+    """

     global _CURRENT_LINE

@@ -59,7 +67,16 @@ def grab_literal(template: str, l_del: str) -> Tuple[str, str]:


 def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
-    """Do a preliminary check to see if a tag could be a standalone."""
+    """Do a preliminary check to see if a tag could be a standalone.
+
+    Args:
+        template: The template. (Not used.)
+        literal: The literal.
+        is_standalone: Whether the tag is standalone.
+
+    Returns:
+        bool: Whether the tag could be a standalone.
+    """

     # If there is a newline, or the previous tag was a standalone
     if literal.find("\n") != -1 or is_standalone:

@@ -77,7 +94,16 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:


 def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
-    """Do a final check to see if a tag could be a standalone."""
+    """Do a final check to see if a tag could be a standalone.
+
+    Args:
+        template: The template.
+        tag_type: The type of the tag.
+        is_standalone: Whether the tag is standalone.
+
+    Returns:
+        bool: Whether the tag could be a standalone.
+    """

     # Check right side if we might be a standalone
     if is_standalone and tag_type not in ["variable", "no escape"]:

@@ -95,7 +121,20 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:


 def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
-    """Parse a tag from a template."""
+    """Parse a tag from a template.
+
+    Args:
+        template: The template.
+        l_del: The left delimiter.
+        r_del: The right delimiter.
+
+    Returns:
+        Tuple[Tuple[str, str], str]: The tag and the template.
+
+    Raises:
+        ChevronError: If the tag is unclosed.
+        ChevronError: If the set delimiter tag is unclosed.
+    """
     global _CURRENT_LINE
     global _LAST_TAG_LINE

@@ -404,36 +443,36 @@ def render(

     Arguments:

-    template -- A file-like object or a string containing the template
+    template -- A file-like object or a string containing the template.

-    data -- A python dictionary with your data scope
+    data -- A python dictionary with your data scope.

-    partials_path -- The path to where your partials are stored
+    partials_path -- The path to where your partials are stored.
         If set to None, then partials won't be loaded from the file system
-        (defaults to '.')
+        (defaults to '.').

     partials_ext -- The extension that you want the parser to look for
-        (defaults to 'mustache')
+        (defaults to 'mustache').

     partials_dict -- A python dictionary which will be searched for partials
         before the filesystem is. {'include': 'foo'} is the same
         as a file called include.mustache
-        (defaults to {})
+        (defaults to {}).

     padding -- This is for padding partials, and shouldn't be used
-        (but can be if you really want to)
+        (but can be if you really want to).

     def_ldel -- The default left delimiter
-        ("{{" by default, as in spec compliant mustache)
+        ("{{" by default, as in spec compliant mustache).

     def_rdel -- The default right delimiter
-        ("}}" by default, as in spec compliant mustache)
+        ("}}" by default, as in spec compliant mustache).

-    scopes -- The list of scopes that get_key will look through
+    scopes -- The list of scopes that get_key will look through.

     warn -- Log a warning when a template substitution isn't found in the data

-    keep -- Keep unreplaced tags when a substitution isn't found in the data
+    keep -- Keep unreplaced tags when a substitution isn't found in the data.


     Returns:
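# A minimal sketch of the vendored mustache renderer documented above:
from langchain_core.utils.mustache import render

print(render("Hello, {{name}}!", {"name": "World"}))  # Hello, World!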
@@ -21,12 +21,27 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()


 # How to type hint this?
 def pre_init(func: Callable) -> Any:
-    """Decorator to run a function before model initialization."""
+    """Decorator to run a function before model initialization.
+
+    Args:
+        func (Callable): The function to run before model initialization.
+
+    Returns:
+        Any: The decorated function.
+    """

     @root_validator(pre=True)
     @wraps(func)
     def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
-        """Decorator to run a function before model initialization."""
+        """Decorator to run a function before model initialization.
+
+        Args:
+            cls (Type[BaseModel]): The model class.
+            values (Dict[str, Any]): The values to initialize the model with.
+
+        Returns:
+            Dict[str, Any]: The values to initialize the model with.
+        """
         # Insert default values
         fields = cls.__fields__
         for name, field_info in fields.items():
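# A minimal sketch of pre_init on a pydantic-v1-style model; the model and
# its field are made up for illustration.
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import pre_init


class Settings(BaseModel):
    retries: int = 1

    @pre_init
    def clamp_retries(cls, values):
        # Runs before model initialization, like root_validator(pre=True).
        values["retries"] = max(1, values.get("retries", 1))
        return values


print(Settings(retries=0).retries)  # 1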
@@ -36,5 +36,12 @@ def stringify_dict(data: dict) -> str:


 def comma_list(items: List[Any]) -> str:
-    """Convert a list to a comma-separated string."""
+    """Convert a list to a comma-separated string.
+
+    Args:
+        items: The list to convert.
+
+    Returns:
+        str: The comma-separated string.
+    """
     return ", ".join(str(item) for item in items)
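# A quick check of comma_list's behavior on mixed types:
from langchain_core.utils.strings import comma_list

assert comma_list([1, "two", 3.0]) == "1, two, 3.0"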
@@ -15,7 +15,18 @@ from langchain_core.pydantic_v1 import SecretStr


 def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
-    """Validate specified keyword args are mutually exclusive."""
+    """Validate specified keyword args are mutually exclusive.
+
+    Args:
+        *arg_groups (Tuple[str, ...]): Groups of mutually exclusive keyword args.
+
+    Returns:
+        Callable: Decorator that validates the specified keyword args
+            are mutually exclusive.
+
+    Raises:
+        ValueError: If more than one arg in a group is defined.
+    """

     def decorator(func: Callable) -> Callable:
         @functools.wraps(func)
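# A minimal sketch of xor_args; the decorated function is hypothetical.
from langchain_core.utils.utils import xor_args


@xor_args(("path", "url"))
def load_source(path=None, url=None):
    return path or url


load_source(path="a.txt")  # ok: exactly one arg of the group is given
# load_source(path="a.txt", url="http://x")  # would raise ValueError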
@@ -41,7 +52,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:


 def raise_for_status_with_text(response: Response) -> None:
-    """Raise an error with the response text."""
+    """Raise an error with the response text.
+
+    Args:
+        response (Response): The response to check for errors.
+
+    Raises:
+        ValueError: If the response has an error status code.
+    """
     try:
         response.raise_for_status()
     except HTTPError as e:

@@ -52,6 +70,12 @@ def raise_for_status_with_text(response: Response) -> None:
 def mock_now(dt_value):  # type: ignore
     """Context manager for mocking out datetime.now() in unit tests.

+    Args:
+        dt_value: The datetime value to use for datetime.now().
+
+    Yields:
+        datetime.datetime: The mocked datetime class.
+
     Example:
         with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
             assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
@@ -86,7 +110,21 @@ def guard_import(
     module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
 ) -> Any:
-    """Dynamically import a module and raise an exception if the module is not
-    installed."""
+    """Dynamically import a module and raise an exception if the module is not
+    installed.
+
+    Args:
+        module_name (str): The name of the module to import.
+        pip_name (str, optional): The name of the module to install with pip.
+            Defaults to None.
+        package (str, optional): The package to import the module from.
+            Defaults to None.
+
+    Returns:
+        Any: The imported module.
+
+    Raises:
+        ImportError: If the module is not installed.
+    """
     try:
         module = importlib.import_module(module_name, package)
     except (ImportError, ModuleNotFoundError):
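# A minimal sketch of guard_import; the optional-dependency line in the
# comment is hypothetical.
from langchain_core.utils.utils import guard_import

json_mod = guard_import("json")  # stdlib module, always importable
# faiss = guard_import("faiss", pip_name="faiss-cpu")  # typical optional-dep use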
@@ -105,7 +143,22 @@ def check_package_version(
     gt_version: Optional[str] = None,
     gte_version: Optional[str] = None,
 ) -> None:
-    """Check the version of a package."""
+    """Check the version of a package.
+
+    Args:
+        package (str): The name of the package.
+        lt_version (str, optional): The version must be less than this.
+            Defaults to None.
+        lte_version (str, optional): The version must be less than or equal to this.
+            Defaults to None.
+        gt_version (str, optional): The version must be greater than this.
+            Defaults to None.
+        gte_version (str, optional): The version must be greater than or equal to this.
+            Defaults to None.
+
+    Raises:
+        ValueError: If the package version does not meet the requirements.
+    """
     imported_version = parse(version(package))
     if lt_version is not None and imported_version >= parse(lt_version):
         raise ValueError(
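# A minimal sketch of check_package_version; the package and bound are
# illustrative.
from langchain_core.utils.utils import check_package_version

check_package_version("pytest", gte_version="7.0")  # raises ValueError if older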
@@ -133,7 +186,11 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
     """Get field names, including aliases, for a pydantic class.

     Args:
-        pydantic_cls: Pydantic class."""
+        pydantic_cls: Pydantic class.
+
+    Returns:
+        Set[str]: Field names.
+    """
     all_required_field_names = set()
     for field in pydantic_cls.__fields__.values():
         all_required_field_names.add(field.name)
@@ -153,6 +210,13 @@ def build_extra_kwargs(
         extra_kwargs: Extra kwargs passed in by user.
         values: Values passed in by user.
         all_required_field_names: All required field names for the pydantic class.
+
+    Returns:
+        Dict[str, Any]: Extra kwargs.
+
+    Raises:
+        ValueError: If a field is specified in both values and extra_kwargs.
+        ValueError: If a field is specified in model_kwargs.
     """
     for field_name in list(values):
         if field_name in extra_kwargs:
@@ -176,7 +240,14 @@ def build_extra_kwargs(


 def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
-    """Convert a string to a SecretStr if needed."""
+    """Convert a string to a SecretStr if needed.
+
+    Args:
+        value (Union[SecretStr, str]): The value to convert.
+
+    Returns:
+        SecretStr: The SecretStr value.
+    """
     if isinstance(value, SecretStr):
         return value
     return SecretStr(value)
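# A minimal sketch of convert_to_secret_str; the key value is made up.
from langchain_core.utils.utils import convert_to_secret_str

secret = convert_to_secret_str("my-api-key")
print(secret)                     # **********  (masked in str/repr)
print(secret.get_secret_value())  # my-api-key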
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "langchain-core"
-version = "0.2.14"
+version = "0.2.19"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"

@@ -38,11 +38,6 @@ python = ">=3.12.4"
 [tool.ruff.lint]
 select = [ "E", "F", "I", "T201",]

-[tool.ruff.lint.per-file-ignores]
-"tests/unit_tests/prompts/test_chat.py" = ["E501"]
-"tests/unit_tests/runnables/test_runnable.py" = ["E501"]
-"tests/unit_tests/runnables/test_graph.py" = ["E501"]
-
 [tool.coverage.run]
 omit = [ "tests/*",]

@@ -66,6 +61,11 @@ optional = true
 [tool.poetry.group.test_integration]
 optional = true

+[tool.ruff.lint.per-file-ignores]
+"tests/unit_tests/prompts/test_chat.py" = [ "E501",]
+"tests/unit_tests/runnables/test_runnable.py" = [ "E501",]
+"tests/unit_tests/runnables/test_graph.py" = [ "E501",]
+
 [tool.poetry.group.lint.dependencies]
 ruff = "^0.5"

@@ -90,12 +90,6 @@ pytest-asyncio = "^0.21.1"
 grandalf = "^0.8"
 pytest-profiling = "^1.7.0"
 responses = "^0.25.0"

-[tool.poetry.group.test.dependencies.langchain-standard-tests]
-path = "../standard-tests"
-develop = true
-
 [[tool.poetry.group.test.dependencies.numpy]]
 version = "^1.24.0"
 python = "<3.12"

@@ -109,3 +103,7 @@ python = ">=3.12"
 [tool.poetry.group.typing.dependencies.langchain-text-splitters]
 path = "../text-splitters"
 develop = true
+
+[tool.poetry.group.test.dependencies.langchain-standard-tests]
+path = "../standard-tests"
+develop = true
@@ -2,25 +2,21 @@ from datetime import datetime
 from typing import (
     Any,
     AsyncIterator,
     Dict,
     Iterable,
     Iterator,
     List,
     Optional,
     Sequence,
     Type,
 )
-from unittest.mock import patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
 import pytest_asyncio

 from langchain_core.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
-from langchain_core.embeddings import Embeddings
+from langchain_core.embeddings import DeterministicFakeEmbedding
 from langchain_core.indexing import InMemoryRecordManager, aindex, index
 from langchain_core.indexing.api import _abatch, _HashedDocument
-from langchain_core.vectorstores import VST, VectorStore
+from langchain_core.vectorstores import InMemoryVectorStore, VectorStore


 class ToyLoader(BaseLoader):

@@ -42,101 +38,6 @@ class ToyLoader(BaseLoader):
         yield document


-class InMemoryVectorStore(VectorStore):
-    """In-memory implementation of VectorStore using a dictionary."""
-
-    def __init__(self, permit_upserts: bool = False) -> None:
-        """Vector store interface for testing things in memory."""
-        self.store: Dict[str, Document] = {}
-        self.permit_upserts = permit_upserts
-
-    def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
-        """Delete the given documents from the store using their IDs."""
-        if ids:
-            for _id in ids:
-                self.store.pop(_id, None)
-
-    async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
-        """Delete the given documents from the store using their IDs."""
-        if ids:
-            for _id in ids:
-                self.store.pop(_id, None)
-
-    def add_documents(  # type: ignore
-        self,
-        documents: Sequence[Document],
-        *,
-        ids: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> List[str]:
-        """Add the given documents to the store (insert behavior)."""
-        if ids and len(ids) != len(documents):
-            raise ValueError(
-                f"Expected {len(ids)} ids, got {len(documents)} documents."
-            )
-
-        if not ids:
-            raise NotImplementedError("This is not implemented yet.")
-
-        for _id, document in zip(ids, documents):
-            if _id in self.store and not self.permit_upserts:
-                raise ValueError(
-                    f"Document with uid {_id} already exists in the store."
-                )
-            self.store[_id] = document
-
-        return list(ids)
-
-    async def aadd_documents(
-        self,
-        documents: Sequence[Document],
-        *,
-        ids: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> List[str]:
-        if ids and len(ids) != len(documents):
-            raise ValueError(
-                f"Expected {len(ids)} ids, got {len(documents)} documents."
-            )
-
-        if not ids:
-            raise NotImplementedError("This is not implemented yet.")
-
-        for _id, document in zip(ids, documents):
-            if _id in self.store and not self.permit_upserts:
-                raise ValueError(
-                    f"Document with uid {_id} already exists in the store."
-                )
-            self.store[_id] = document
-        return list(ids)
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[List[Dict[Any, Any]]] = None,
-        **kwargs: Any,
-    ) -> List[str]:
-        """Add the given texts to the store (insert behavior)."""
-        raise NotImplementedError()
-
-    @classmethod
-    def from_texts(
-        cls: Type[VST],
-        texts: List[str],
-        embedding: Embeddings,
-        metadatas: Optional[List[Dict[Any, Any]]] = None,
-        **kwargs: Any,
-    ) -> VST:
-        """Create a vector store from a list of texts."""
-        raise NotImplementedError()
-
-    def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        """Find the most similar documents to the given query."""
-        raise NotImplementedError()
-
-
 @pytest.fixture
 def record_manager() -> InMemoryRecordManager:
     """Timestamped set fixture."""
@@ -156,13 +57,15 @@ async def arecord_manager() -> InMemoryRecordManager:
 @pytest.fixture
 def vector_store() -> InMemoryVectorStore:
     """Vector store fixture."""
-    return InMemoryVectorStore()
+    embeddings = DeterministicFakeEmbedding(size=5)
+    return InMemoryVectorStore(embeddings)


 @pytest.fixture
 def upserting_vector_store() -> InMemoryVectorStore:
     """Vector store fixture."""
-    return InMemoryVectorStore(permit_upserts=True)
+    embeddings = DeterministicFakeEmbedding(size=5)
+    return InMemoryVectorStore(embeddings)


 def test_indexing_same_content(

@@ -286,7 +189,7 @@ def test_index_simple_delete_full(

     doc_texts = set(
         # Ignoring type since doc should be in the store and not a None
-        vector_store.store.get(uid).page_content  # type: ignore
+        vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
     )
     assert doc_texts == {"mutated document 1", "This is another document."}

@@ -368,7 +271,7 @@ async def test_aindex_simple_delete_full(

     doc_texts = set(
         # Ignoring type since doc should be in the store and not a None
-        vector_store.store.get(uid).page_content  # type: ignore
+        vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
     )
     assert doc_texts == {"mutated document 1", "This is another document."}

@@ -659,7 +562,7 @@ def test_incremental_delete(

     doc_texts = set(
         # Ignoring type since doc should be in the store and not a None
-        vector_store.store.get(uid).page_content  # type: ignore
+        vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
     )
     assert doc_texts == {"This is another document.", "This is a test document."}

@@ -718,7 +621,7 @@ def test_incremental_delete(

     doc_texts = set(
         # Ignoring type since doc should be in the store and not a None
-        vector_store.store.get(uid).page_content  # type: ignore
+        vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
     )
     assert doc_texts == {

@@ -786,7 +689,7 @@ def test_incremental_indexing_with_batch_size(

     doc_texts = set(
         # Ignoring type since doc should be in the store and not a None
-        vector_store.store.get(uid).page_content  # type: ignore
+        vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
     )
     assert doc_texts == {"1", "2", "3", "4"}

@@ -836,7 +739,7 @@ def test_incremental_delete_with_batch_size(

     doc_texts = set(
         # Ignoring type since doc should be in the store and not a None
-        vector_store.store.get(uid).page_content  # type: ignore
+        vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
     )
     assert doc_texts == {"1", "2", "3", "4"}

@@ -981,7 +884,7 @@ async def test_aincremental_delete(

     doc_texts = set(
         # Ignoring type since doc should be in the store and not a None
-        vector_store.store.get(uid).page_content  # type: ignore
+        vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
     )
     assert doc_texts == {"This is another document.", "This is a test document."}

@@ -1040,7 +943,7 @@ async def test_aincremental_delete(

     doc_texts = set(
         # Ignoring type since doc should be in the store and not a None
-        vector_store.store.get(uid).page_content  # type: ignore
+        vector_store.get_by_ids([uid])[0].page_content  # type: ignore
         for uid in vector_store.store
     )
     assert doc_texts == {

@@ -1232,8 +1135,10 @@ def test_deduplication_v2(

     # using in memory implementation here
     assert isinstance(vector_store, InMemoryVectorStore)
+
+    ids = list(vector_store.store.keys())
     contents = sorted(
-        [document.page_content for document in vector_store.store.values()]
+        [document.page_content for document in vector_store.get_by_ids(ids)]
     )
     assert contents == ["1", "2", "3"]

@@ -1370,11 +1275,19 @@ def test_indexing_custom_batch_size(
     ids = [_HashedDocument.from_document(doc).uid for doc in docs]

     batch_size = 1
-    with patch.object(vector_store, "add_documents") as mock_add_documents:
+
+    original = vector_store.add_documents
+
+    try:
+        mock_add_documents = MagicMock()
+        vector_store.add_documents = mock_add_documents  # type: ignore
+
         index(docs, record_manager, vector_store, batch_size=batch_size)
         args, kwargs = mock_add_documents.call_args
         assert args == (docs,)
         assert kwargs == {"ids": ids, "batch_size": batch_size}
+    finally:
+        vector_store.add_documents = original  # type: ignore


 async def test_aindexing_custom_batch_size(

@@ -1390,8 +1303,9 @@ async def test_aindexing_custom_batch_size(
     ids = [_HashedDocument.from_document(doc).uid for doc in docs]

     batch_size = 1
-    with patch.object(vector_store, "aadd_documents") as mock_add_documents:
-        await aindex(docs, arecord_manager, vector_store, batch_size=batch_size)
-        args, kwargs = mock_add_documents.call_args
-        assert args == (docs,)
-        assert kwargs == {"ids": ids, "batch_size": batch_size}
+    mock_add_documents = AsyncMock()
+    vector_store.aadd_documents = mock_add_documents  # type: ignore
+    await aindex(docs, arecord_manager, vector_store, batch_size=batch_size)
+    args, kwargs = mock_add_documents.call_args
+    assert args == (docs,)
+    assert kwargs == {"ids": ids, "batch_size": batch_size}
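# A minimal sketch of the InMemoryVectorStore + DeterministicFakeEmbedding
# combination the updated fixtures rely on:
from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore

store = InMemoryVectorStore(DeterministicFakeEmbedding(size=5))
store.add_documents([Document(page_content="hello")], ids=["1"])
print(store.get_by_ids(["1"])[0].page_content)  # hello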
@@ -279,7 +279,7 @@ class CustomChat(GenericFakeChatModel):
 async def test_can_swap_caches() -> None:
     """Test that we can use a different cache object.

-    This test verifies that when we fetch teh llm_string representation
+    This test verifies that when we fetch the llm_string representation
     of the chat model, we can swap the cache object and still get the same
     result.
     """
@@ -1,19 +1,16 @@
 from langchain_core.load import dumpd, load
-from langchain_core.messages import (
-    AIMessage,
-    AIMessageChunk,
-    InvalidToolCall,
-    ToolCall,
-    ToolCallChunk,
-)
+from langchain_core.messages import AIMessage, AIMessageChunk
+from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
+from langchain_core.messages.tool import tool_call as create_tool_call
+from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk


 def test_serdes_message() -> None:
     msg = AIMessage(
         content=[{"text": "blah", "type": "text"}],
-        tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
+        tool_calls=[create_tool_call(name="foo", args={"bar": 1}, id="baz")],
         invalid_tool_calls=[
-            InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
+            create_invalid_tool_call(name="foobad", args="blah", id="booz", error="bad")
         ],
     )
     expected = {

@@ -23,9 +20,17 @@ def test_serdes_message() -> None:
         "kwargs": {
             "type": "ai",
             "content": [{"text": "blah", "type": "text"}],
-            "tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
+            "tool_calls": [
+                {"name": "foo", "args": {"bar": 1}, "id": "baz", "type": "tool_call"}
+            ],
             "invalid_tool_calls": [
-                {"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
+                {
+                    "name": "foobad",
+                    "args": "blah",
+                    "id": "booz",
+                    "error": "bad",
+                    "type": "invalid_tool_call",
+                }
             ],
         },
     }

@@ -38,8 +43,13 @@ def test_serdes_message_chunk() -> None:
     chunk = AIMessageChunk(
         content=[{"text": "blah", "type": "text"}],
         tool_call_chunks=[
-            ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
-            ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
+            create_tool_call_chunk(name="foo", args='{"bar": 1}', id="baz", index=0),
+            create_tool_call_chunk(
+                name="foobad",
+                args="blah",
+                id="booz",
+                index=1,
+            ),
         ],
     )
     expected = {

@@ -49,18 +59,33 @@ def test_serdes_message_chunk() -> None:
         "kwargs": {
             "type": "AIMessageChunk",
             "content": [{"text": "blah", "type": "text"}],
-            "tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
+            "tool_calls": [
+                {"name": "foo", "args": {"bar": 1}, "id": "baz", "type": "tool_call"}
+            ],
             "invalid_tool_calls": [
                 {
                     "name": "foobad",
                     "args": "blah",
                     "id": "booz",
                     "error": None,
+                    "type": "invalid_tool_call",
                 }
             ],
             "tool_call_chunks": [
-                {"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
-                {"name": "foobad", "args": "blah", "id": "booz", "index": 1},
+                {
+                    "name": "foo",
+                    "args": '{"bar": 1}',
+                    "id": "baz",
+                    "index": 0,
+                    "type": "tool_call_chunk",
+                },
+                {
+                    "name": "foobad",
+                    "args": "blah",
+                    "id": "booz",
+                    "index": 1,
+                    "type": "tool_call_chunk",
+                },
             ],
         },
     }
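# A minimal sketch of the factory-based construction these tests switched to,
# round-tripped through dumpd/load:
from langchain_core.load import dumpd, load
from langchain_core.messages import AIMessage
from langchain_core.messages.tool import tool_call as create_tool_call

msg = AIMessage(
    content="",
    tool_calls=[create_tool_call(name="foo", args={"bar": 1}, id="baz")],
)
assert load(dumpd(msg)) == msg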
Some files were not shown because too many files have changed in this diff.