Compare commits


70 Commits

Author SHA1 Message Date
Bagatur
9d0c1d2dc9 docs: specify init_chat_model version (#24274) 2024-07-15 16:29:06 +00:00
MoraxMa
a7296bddc2 docs: updated Tongyi package (#24259)
* updated pip install package
2024-07-15 16:25:35 +00:00
Bagatur
c9473367b1 langchain[patch]: Release 0.2.8 (#24273) 2024-07-15 16:05:51 +00:00
JP-Ellis
f77659463a core[patch]: allow message utils to work with lcel (#23743)
The function `convert_to_messages` has had an expansion of the
arguments it can take:

1. Previously, it could only take a `Sequence` in order to iterate over
it. This has been broadened slightly to an `Iterable` (which should have
no other impact).
2. Support for `PromptValue` and `BaseChatPromptTemplate` has been
added. These are generated when combining messages using the overloaded
`+` operator.

Functions which rely on `convert_to_messages` (namely `filter_messages`,
`merge_message_runs` and `trim_messages`) have had the type of their
arguments similarly expanded.

Resolves #23706.
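For illustration, here is a minimal sketch of the new behavior, mirroring the `merge_message_runs` example added to the docs later in this diff (the specific messages are illustrative):

```python
from langchain_core.messages import HumanMessage, SystemMessage, merge_message_runs

# Chaining messages with the overloaded `+` operator produces a chat prompt
# template rather than a plain list; the message utilities now accept it directly.
messages = (
    SystemMessage("you're a good assistant.")
    + SystemMessage("you always respond with a joke.")
    + HumanMessage("i wonder why it's called langchain")
)

merged = merge_message_runs(messages)
print("\n\n".join(repr(x) for x in merged))
```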


---------

Signed-off-by: JP-Ellis <josh@jpellis.me>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-07-15 08:58:05 -07:00
Harold Martin
ccdaf14eff docs: Spell check fixes (#24217)
**Description:** Spell-check fixes for docs, comments, and a couple of
strings. No code changes (e.g. variable names).
**Issue:** none
**Dependencies:** none
**Twitter handle:** hmartin
2024-07-15 15:51:43 +00:00
Leonid Ganeline
cacdf96f9c core docstrings tracers update (#24211)
Added missing docstrings. Formatted docstrings to a consistent form.
2024-07-15 11:37:09 -04:00
Leonid Ganeline
36ee083753 core: docstrings utils update (#24213)
Added missing docstrings. Formatted docstrings to a consistent form.
2024-07-15 11:36:00 -04:00
thehunmonkgroup
e8a21146d3 community[patch]: upgrade default model for ChatAnyscale (#24232)
The old default `meta-llama/Llama-2-7b-chat-hf` is no longer supported.
2024-07-15 11:34:59 -04:00
Bagatur
a0958c0607 docs: more tool call -> tool message docs (#24271) 2024-07-15 07:55:07 -07:00
Bagatur
620b118c70 core[patch]: Release 0.2.19 (#24272) 2024-07-15 07:51:30 -07:00
ccurme
888fbc07b5 core[patch]: support passing args_schema through as_tool (#24269)
Note: this allows the schema to be passed in positionally.

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableLambda


class Add(BaseModel):
    """Add two integers together."""

    a: int = Field(..., description="First integer")
    b: int = Field(..., description="Second integer")


def add(input: dict) -> int:
    return input["a"] + input["b"]


runnable = RunnableLambda(add)
as_tool = runnable.as_tool(Add)
as_tool.args_schema.schema()
```
```
{'title': 'Add',
 'description': 'Add two integers together.',
 'type': 'object',
 'properties': {'a': {'title': 'A',
   'description': 'First integer',
   'type': 'integer'},
  'b': {'title': 'B', 'description': 'Second integer', 'type': 'integer'}},
 'required': ['a', 'b']}
```
2024-07-15 07:51:05 -07:00
ccurme
ab2d7821a7 fireworks[patch]: use firefunction-v2 in standard tests (#24264) 2024-07-15 13:15:08 +00:00
ccurme
6fc7610b1c standard-tests[patch]: update test_bind_runnables_as_tools (#24241)
Reduce number of tool arguments from two to one.
2024-07-15 08:35:07 -04:00
Bagatur
0da5078cad langchain[minor]: Generic configurable model (#23419)
alternative to
[#23244](https://github.com/langchain-ai/langchain/pull/23244). Allows
you to use chat model declarative methods on a configurable model.

![Screenshot 2024-06-25 at 1 07 10 PM](https://github.com/langchain-ai/langchain/assets/22008038/910d1694-9b7b-46bc-bc2e-3792df9321d6)
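A short usage sketch, taken from the notebook changes further down in this diff:

```python
from langchain.chat_models import init_chat_model

# With no fixed model specified, "model" and "model_provider" are
# configurable at runtime.
configurable_model = init_chat_model(temperature=0)

configurable_model.invoke(
    "what's your name", config={"configurable": {"model": "gpt-4o"}}
)
```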
2024-07-15 01:11:01 +00:00
Bagatur
d0728b0ba0 core[patch]: add tool name to tool message (#24243)
Copying current ToolNode behavior
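As a rough sketch of the effect (the tool below is illustrative; invoking with a ToolCall dict assumes the recent core changes in this log):

```python
from langchain_core.tools import tool


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    return a * b


# Invoking with a ToolCall dict yields a ToolMessage that now carries the tool name.
msg = multiply.invoke(
    {"name": "multiply", "args": {"a": 2, "b": 3}, "id": "123", "type": "tool_call"}
)
print(msg.name)  # "multiply"
```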
2024-07-15 00:42:40 +00:00
Bagatur
9224027e45 docs: tool artifacts how to (#24198) 2024-07-14 17:04:47 -07:00
Bagatur
5c3e2612da core[patch]: Release 0.2.18 (#24230) 2024-07-13 09:14:43 -07:00
Bagatur
65321bf975 core[patch]: fix ToolCall "type" when streaming (#24218) 2024-07-13 08:59:03 -07:00
Jacob Lee
2b7d1cdd2f docs[patch]: Update tool child run docs (#24160)
Documents #24143
2024-07-13 07:52:37 -07:00
Anush
a653b209ba qdrant: test new QdrantVectorStore (#24165)
## Description

This PR adds integration tests to follow up on #24164.

By default, the tests use an in-memory instance.

To run the full suite of tests, with both in-memory and Qdrant server:

```
$ docker run -p 6333:6333 qdrant/qdrant

$ make test

$ make integration_test
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-12 23:59:30 +00:00
Roman Solomatin
f071581aea openai[patch]: update openai params (#23691)
**Description:** Explicitly add parameters from the OpenAI API.




---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-12 16:53:33 -07:00
Leonid Ganeline
f0a7581b50 milvus: docstring (#23151)
Added missing docstrings. Formatted docstrings to the consistent format
(used in the API Reference).

---------

Co-authored-by: Isaac Francisco <78627776+isahers1@users.noreply.github.com>
Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-12 23:25:31 +00:00
Christian D. Glissov
474b88326f langchain_qdrant: Added method "_asimilarity_search_with_relevance_scores" to Qdrant class (#23954)
I stumbled upon a bug that led to different similarity scores between
the async and sync similarity searches with relevance scores in Qdrant.
The reason is that `_asimilarity_search_with_relevance_scores` was
missing, so langchain_qdrant fell back to the method of the vectorstore
base class, leading to drastically different results.

To illustrate the magnitude, here are the results of running an identical
search in a test vectorstore.

Output of asimilarity_search_with_relevance_scores:
[0.9902903374601824, 0.9472135924938804, 0.8535534011299859]

Output of similarity_search_with_relevance_scores:
[0.9805806749203648, 0.8944271849877607, 0.7071068022599718]

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-12 23:25:20 +00:00
Bagatur
bdc03997c9 standard-tests[patch]: check for ToolCall["type"] (#24209) 2024-07-12 16:17:34 -07:00
Nada Amin
3f1cf00d97 docs: Improve neo4j semantic templates (#23939)
I made some changes based on the issues I stumbled on while following
the README of neo4j-semantic-ollama.
I made the changes to the ollama variant, and can also port the relevant
ones to the layer variant once this is approved.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-12 23:08:25 +00:00
Nada Amin
6b47c7361e docs: fix code usage to use the ollama variant (#23937)
**Description:** the template neo4j-semantic-ollama uses an import from
the neo4j-semantic-layer template instead of its own.

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-12 23:07:42 +00:00
Anirudh31415926535
7677ceea60 docs: model parameter mandatory for cohere embedding and rerank (#23349)
The latest langchain-cohere SDK mandates passing the model parameter into
the Embeddings and Reranker inits.

This PR is to update the docs to reflect these changes.
2024-07-12 23:07:28 +00:00
Miroslav
aee55eda39 community: Skip login to HuggingFaceHub when token is not set (#21561)
- **Description:** Skip login to Hugging Face Hub when
`huggingfacehub_api_token` is not set. This is needed when using a custom
`endpoint_url` outside of HuggingFaceHub.
- **Issue:** fixes
https://github.com/langchain-ai/langchain/issues/20342 and
https://github.com/langchain-ai/langchain/issues/19685
- **Dependencies:** None


- **Tests and docs:**
  1. Tested with a locally available TGI endpoint
  2. Example usage:
```python
from langchain_community.llms import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    endpoint_url='http://localhost:8080',
    server_kwargs={
        "headers": {"Content-Type": "application/json"}
    }
)
resp = llm.invoke("Tell me a joke")
print(resp)
```
 Also tested against HF Endpoints
 ```python
 from langchain_community.llms import HuggingFaceEndpoint
huggingfacehub_api_token = "hf_xyz"
repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
llm = HuggingFaceEndpoint(
    huggingfacehub_api_token=huggingfacehub_api_token,
    repo_id=repo_id,
)
resp = llm.invoke("Tell me a joke")
print(resp)
 ```
---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-12 22:10:32 +00:00
Anush
d09dda5a08 qdrant: Bump patch version (#24168)
# Description

To release a new version of `langchain-qdrant` after #24165 and #24166.
2024-07-12 14:48:50 -07:00
Bagatur
12950cc602 standard-tests[patch]: improve runnable tool description (#24210) 2024-07-12 21:33:56 +00:00
Erick Friis
e8ee781a42 ibm: move to external repo (#24208) 2024-07-12 21:14:24 +00:00
Bagatur
02e71cebed together[patch]: Release 0.1.4 (#24205) 2024-07-12 13:59:58 -07:00
Bagatur
259d4d2029 anthropic[patch]: Release 0.1.20 (#24204) 2024-07-12 13:59:15 -07:00
Bagatur
3aed74a6fc fireworks[patch]: Release 0.1.5 (#24203) 2024-07-12 13:58:58 -07:00
Bagatur
13b0d7ec8f openai[patch]: Release 0.1.16 (#24202) 2024-07-12 13:58:39 -07:00
Bagatur
71cd6e6feb groq[patch]: Release 0.1.7 (#24201) 2024-07-12 13:58:19 -07:00
Bagatur
99054e19eb mistralai[patch]: Release 0.1.10 (#24200) 2024-07-12 13:57:58 -07:00
Bagatur
7a1321e2f9 ibm[patch]: Release 0.1.10 (#24199) 2024-07-12 13:57:38 -07:00
Bagatur
cb5031f22f integrations[patch]: require core >=0.2.17 (#24207) 2024-07-12 20:54:01 +00:00
Nithish Raghunandanan
f1618ec540 couchbase: Add standard and semantic caches (#23607)

**Description:** Add support for caching (standard + semantic) LLM
responses using Couchbase



---------

Co-authored-by: Nithish Raghunandanan <nithishr@users.noreply.github.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-12 20:30:03 +00:00
Eugene Yurtsev
8d82a0d483 core[patch]: Mark GraphVectorStore as beta (#24195)
* This PR marks graph vectorstore as beta
2024-07-12 14:28:06 -04:00
Bagatur
0a1e475a30 core[patch]: Release 0.2.17 (#24189) 2024-07-12 17:08:29 +00:00
Bagatur
6166ea67a8 core[minor]: rename ToolMessage.raw_output -> artifact (#24185) 2024-07-12 09:52:44 -07:00
Jean Nshuti
d77d9bfc00 community[patch]: update typo document content returned from semanticscholar (#24175)
Update "astract" -> abstract
2024-07-12 15:40:47 +00:00
Leonid Ganeline
aa3e3cfa40 core[patch]: docstrings runnables update (#24161)
Added missing docstrings. Formatted docstrings to a consistent form.
2024-07-12 11:27:06 -04:00
mumu
14ba1d4b45 docs: fix numeric errors in tools_chain.ipynb (#24169)
Description: Corrected several numeric errors in the
docs/docs/how_to/tools_chain.ipynb file to ensure the accuracy of the
documentation.
2024-07-12 11:26:26 -04:00
Ikko Eltociear Ashimine
18da9f5e59 docs: update custom_chat_model.ipynb (#24170)
characetrs -> characters
2024-07-12 06:48:22 -04:00
Tomaz Bratanic
d3a2b9fae0 Fix neo4j type error on missing constraint information (#24177)
If you use `refresh_schema=False`, then the metadata constraint doesn't
exist. ATM, we used the default `None` in the constraint check, but then
`any` fails because it can't iterate over a `None` value (illustrated below).
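A minimal illustration of the failure mode and the guard (generic Python, not the actual patch):

```python
constraints = None  # constraint metadata is absent when refresh_schema=False

# any(None) raises TypeError: 'NoneType' object is not iterable.
# Defaulting to an empty list keeps the check safe:
has_constraints = any(constraints or [])
print(has_constraints)  # False
```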
2024-07-12 06:39:29 -04:00
Anush
7014d07cab qdrant: new Qdrant implementation (#24164) 2024-07-12 04:52:02 +02:00
Xander Dumaine
35784d1c33 langchain[minor]: add document_variable_name to create_stuff_documents_chain (#24083)
- **Description:** `StuffDocumentsChain` uses `LLMChain`, which is
deprecated in favor of LangChain runnables. `create_stuff_documents_chain` is the
replacement, but needs support for `document_variable_name` to allow
multiple uses of the chain within a longer chain (a sketch follows below).
- **Issue:** none
- **Dependencies:** none
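A hedged sketch of the added parameter; the prompt, fake chat model, and variable name below are illustrative:

```python
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.prompts import ChatPromptTemplate

llm = GenericFakeChatModel(messages=iter(["a short summary"]))
prompt = ChatPromptTemplate.from_messages(
    [("human", "Summarize:\n\n{source_docs}")]
)

# document_variable_name lets the formatted documents fill a custom placeholder
# instead of the default "context", so the chain can be reused inside a larger chain.
chain = create_stuff_documents_chain(llm, prompt, document_variable_name="source_docs")
chain.invoke({"source_docs": [Document(page_content="LangChain is a framework.")]})
```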
2024-07-12 02:31:46 +00:00
Eugene Yurtsev
8858846607 milvus[patch]: Fix Milvus vectorstore for newer versions of langchain-core (#24152)
Fix for: https://github.com/langchain-ai/langchain/issues/24116

This keeps the old behavior of add_documents and add_texts
2024-07-11 18:51:18 -07:00
thedavgar
ffe6ca986e community: Fix bug in Azure Search vectorstore asynchronous search (#24081)

**Description**:
This PR fixes a bug described in issue #24064 that occurs when using the
AzureSearch vectorstore with the asynchronous search methods, which are
also used by the retriever. The proposed change simply makes access to
the embedding optional, because it is not used anywhere to retrieve
documents. The synchronous retrieval methods do not use the embedding
either.

With this PR the code given by the user in the issue works.

```python
vectorstore = AzureSearch(
    azure_search_endpoint=os.getenv("AI_SEARCH_ENDPOINT_SECRET"),
    azure_search_key=os.getenv("AI_SEARCH_API_KEY"),
    index_name=os.getenv("AI_SEARCH_INDEX_NAME_SECRET"),
    fields=fields,
    embedding_function=encoder,
)

retriever = vectorstore.as_retriever(search_type="hybrid", k=2)

await vectorstore.avector_search("what is the capital of France")
await retriever.ainvoke("what is the capital of France")
```

**Issue**:
The Azure Search vectorstore is not working when searching for documents
with asynchronous methods, as described in issue #24064.

**Dependencies**:
There are no extra dependencies required for this change.

---------

Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
2024-07-11 18:32:19 -07:00
Anush
7790d67f94 qdrant: New sparse embeddings provider interface - PART 1 (#24015)
## Description

This PR introduces a new sparse embedding provider interface to work
with the new Qdrant implementation that will follow this PR.

Additionally, an implementation of this interface is provided with
https://github.com/qdrant/fastembed.

This PR will be followed by
https://github.com/Anush008/langchain/pull/3.
2024-07-11 17:07:25 -07:00
Erick Friis
1132fb801b core: release 0.2.16 (#24159) 2024-07-11 23:59:41 +00:00
Nuno Campos
1d37aa8403 core: Remove extra newline (#24157) 2024-07-11 23:55:36 +00:00
ccurme
cb95198398 standard-tests[patch]: add tests for runnables as tools and streaming usage metadata (#24153) 2024-07-11 18:30:05 -04:00
Erick Friis
d002fa902f infra: fix redundant matrix config (#24151) 2024-07-11 15:15:41 -07:00
Bagatur
8d100c58de core[patch]: Tool accept RunnableConfig (#24143)
Relies on #24038
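A hedged sketch of what this enables; the tool is illustrative, and the behavior shown (a parameter annotated with `RunnableConfig` is injected at runtime rather than exposed in the tool's schema) follows the `tool_configure` how-to linked elsewhere in this diff:

```python
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool


@tool
def reverse(text: str, config: RunnableConfig) -> str:
    """Reverse the text."""
    # `config` is populated at runtime (callbacks, tags, metadata, ...)
    # and does not appear in the tool's argument schema.
    return text[::-1]


reverse.invoke({"text": "abc"}, config={"tags": ["demo"]})
```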

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-11 22:13:17 +00:00
Bagatur
5fd1e67808 core[minor], integrations...[patch]: Support ToolCall as Tool input and ToolMessage as Tool output (#24038)
Changes:
- ToolCall, InvalidToolCall and ToolCallChunk can all accept a "type"
parameter now
- LLM integration packages add "type" to all the above
- Tool supports ToolCall inputs that have "type" specified
- Tool outputs ToolMessage when a ToolCall is passed as input
- Tools can separately specify ToolMessage.content and
ToolMessage.raw_output
- Tools emit events for validation errors (using on_tool_error and
on_tool_end)

Example:
```python
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
    arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
    """A Structured Tool"""
    return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}


def test_tool_call_input_tool_message_with_raw_output() -> None:
    tool_call: Dict = {
        "name": "structured_api",
        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
        "id": "123",
        "type": "tool_call",
    }
    expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
    tool = _mock_structured_tool_with_raw_output
    actual = tool.invoke(tool_call)
    assert actual == expected

    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)

    actual_content = tool.invoke(tool_call["args"])
    assert actual_content == expected.content
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-07-11 14:54:02 -07:00
Bagatur
eeb996034b core[patch]: Release 0.2.15 (#24149) 2024-07-11 21:34:25 +00:00
Nuno Campos
03fba07d15 core[patch]: Update styles for mermaid graphs (#24147) 2024-07-11 14:19:36 -07:00
Jacob Lee
c481a2715d docs[patch]: Add structural example to style guide (#24133)
CC @nfcampos
2024-07-11 13:20:14 -07:00
ccurme
8ee8ca7c83 core[patch]: propagate parse_docstring to tool decorator (#24123)
Disabled by default.

```python
from langchain_core.tools import tool

@tool(parse_docstring=True)
def foo(bar: str, baz: int) -> str:
    """The foo.

    Args:
        bar: this is the bar
        baz: this is the baz
    """
    return bar


foo.args_schema.schema()
```
```json
{
  "title": "fooSchema",
  "description": "The foo.",
  "type": "object",
  "properties": {
    "bar": {
      "title": "Bar",
      "description": "this is the bar",
      "type": "string"
    },
    "baz": {
      "title": "Baz",
      "description": "this is the baz",
      "type": "integer"
    }
  },
  "required": [
    "bar",
    "baz"
  ]
}
```
2024-07-11 20:11:45 +00:00
Jacob Lee
4121d4151f docs[patch]: Fix typo (#24132)
CC @efriis
2024-07-11 20:10:48 +00:00
Erick Friis
bd18faa2a0 infra: add SQLAlchemy to min version testing (#23186)
preventing issues like #22546 

Notes:
- this will only affect release CI. We may want to consider running
unit tests with min versions in PR CI in some form
- because this only affects release CI, it could create annoying issues
releasing while I'm on vacation. Unless anyone feels strongly, I'll wait
to merge this til when I'm back
2024-07-11 20:09:57 +00:00
Jacob Lee
f1f1f75782 community[patch]: Make AzureML endpoint return AI messages for type assistant (#24085) 2024-07-11 21:45:30 +02:00
Eugene Yurtsev
4ba14adec6 core[patch]: Clean up indexing test code (#24139)
Refactor the code to use the existing InMemoryVectorStore.

This change is needed for another PR that moves some of the imports
around (and messes up the mock.patch in this file)
2024-07-11 18:54:46 +00:00
Atul R
457677c1b7 community: Fixes use of ImagePromptTemplate with Ollama (#24140)
Description: Fixes ImagePromptTemplate for multimodal LLMs like llava when
using Ollama.
Twitter handle: https://x.com/a7ulr

Details:

When using llava models or any Ollama multimodal LLMs and passing images
in the prompt as URLs, langchain breaks with this error.

```python
image_url_components = image_url.split(",")
                           ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'dict' object has no attribute 'split'
```

From the looks of it, there was a bug where the condition checked for a
`url` field in the variable but missed actually assigning it.

This PR fixes ImagePromptTemplate for Multimodal llms like llava when
using Ollama specifically.

@hwchase17
2024-07-11 11:31:48 -07:00
Matt
8327925ab7 community: support additional Azure Search options (#24134)
- **Description:** Support additional kwargs options for the Azure
Search client (described here:
https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations).
A rough sketch follows below.
    - **Issue:** N/A
    - **Dependencies:** No additional dependencies
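A rough sketch of passing client options through, modeled on the AzureSearch construction shown earlier in this log; the specific keyword arguments (`retry_total`, `logging_enable`) are azure-core client configurations used here as assumptions, not values taken from the PR:

```python
import os

from langchain_community.vectorstores.azuresearch import AzureSearch

vector_store = AzureSearch(
    azure_search_endpoint=os.getenv("AI_SEARCH_ENDPOINT_SECRET"),
    azure_search_key=os.getenv("AI_SEARCH_API_KEY"),
    index_name=os.getenv("AI_SEARCH_INDEX_NAME_SECRET"),
    embedding_function=encoder,  # `encoder` is an embeddings model, as in the earlier AzureSearch example
    # Additional kwargs are forwarded to the underlying Azure Search client:
    retry_total=4,        # assumed azure-core retry option
    logging_enable=True,  # assumed azure-core logging option
)
```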

---------
2024-07-11 18:22:36 +00:00
ccurme
122e80e04d core[patch]: add versionadded to as_tool (#24138) 2024-07-11 18:08:08 +00:00
202 changed files with 10163 additions and 5784 deletions

View File

@@ -9,6 +9,7 @@ MIN_VERSION_LIBS = [
"langchain-community",
"langchain",
"langchain-text-splitters",
"SQLAlchemy",
]

View File

@@ -21,14 +21,6 @@ jobs:
run:
working-directory: ${{ inputs.working-directory }}
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"
name: "poetry run pytest -m compile tests/integration_tests #${{ inputs.python-version }}"
steps:
- uses: actions/checkout@v4

View File

@@ -14,10 +14,6 @@ env:
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.12"
name: "check doc imports #${{ inputs.python-version }}"
steps:
- uses: actions/checkout@v4

View File

@@ -78,7 +78,7 @@ def _load_module_members(module_path: str, namespace: str) -> ModuleMembers:
continue
if inspect.isclass(type_):
# The clasification of the class is used to select a template
# The type of the class is used to select a template
# for the object when rendering the documentation.
# See `templates` directory for defined templates.
# This is a hacky solution to distinguish between different

View File

@@ -821,7 +821,7 @@ We recommend this method as a starting point when working with structured output
- If multiple underlying techniques are supported, you can supply a `method` parameter to
[toggle which one is used](/docs/how_to/structured_output/#advanced-specifying-the-method-for-structuring-outputs).
You may want or need to use other techiniques if:
You may want or need to use other techniques if:
- The chat model you are using does not support tool calling.
- You are working with very complex schemas and the model is having trouble generating outputs that conform.

View File

@@ -33,6 +33,8 @@ Some examples include:
- [Build a Simple LLM Application with LCEL](/docs/tutorials/llm_chain/)
- [Build a Retrieval Augmented Generation (RAG) App](/docs/tutorials/rag/)
A good structural rule of thumb is to follow the structure of this [example from Numpy](https://numpy.org/numpy-tutorials/content/tutorial-svd.html).
Here are some high-level tips on writing a good tutorial:

View File

@@ -15,6 +15,12 @@
"\n",
"Make sure you have the integration packages installed for any model providers you want to support. E.g. you should have `langchain-openai` installed to init an OpenAI model.\n",
"\n",
":::\n",
"\n",
":::info Requires ``langchain >= 0.2.8``\n",
"\n",
"This functionality was added in ``langchain-core == 0.2.8``. Please make sure your package is up to date.\n",
"\n",
":::"
]
},
@@ -25,7 +31,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain langchain-openai langchain-anthropic langchain-google-vertexai"
"%pip install -qU langchain>=0.2.8 langchain-openai langchain-anthropic langchain-google-vertexai"
]
},
{
@@ -76,32 +82,6 @@
"print(\"Gemini 1.5: \" + gemini_15.invoke(\"what's your name\").content + \"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "fff9a4c8-b6ee-4a1a-8d3d-0ecaa312d4ed",
"metadata": {},
"source": [
"## Simple config example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75c25d39-bf47-4b51-a6c6-64d9c572bfd6",
"metadata": {},
"outputs": [],
"source": [
"user_config = {\n",
" \"model\": \"...user-specified...\",\n",
" \"model_provider\": \"...user-specified...\",\n",
" \"temperature\": 0,\n",
" \"max_tokens\": 1000,\n",
"}\n",
"\n",
"llm = init_chat_model(**user_config)\n",
"llm.invoke(\"what's your name\")"
]
},
{
"cell_type": "markdown",
"id": "f811f219-5e78-4b62-b495-915d52a22532",
@@ -125,12 +105,215 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da07b5c0-d2e6-42e4-bfcd-2efcfaae6221",
"cell_type": "markdown",
"id": "476a44db-c50d-4846-951d-0f1c9ba8bbaa",
"metadata": {},
"outputs": [],
"source": []
"source": [
"## Creating a configurable model\n",
"\n",
"You can also create a runtime-configurable model by specifying `configurable_fields`. If you don't specify a `model` value, then \"model\" and \"model_provider\" be configurable by default."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6c037f27-12d7-4e83-811e-4245c0e3ba58",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?\", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d576307f90', 'finish_reason': 'stop', 'logprobs': None}, id='run-5428ab5c-b5c0-46de-9946-5d4ca40dbdc8-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"configurable_model = init_chat_model(temperature=0)\n",
"\n",
"configurable_model.invoke(\n",
" \"what's your name\", config={\"configurable\": {\"model\": \"gpt-4o\"}}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "321e3036-abd2-4e1f-bcc6-606efd036954",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"My name is Claude. It's nice to meet you!\", response_metadata={'id': 'msg_012XvotUJ3kGLXJUWKBVxJUi', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-1ad1eefe-f1c6-4244-8bc6-90e2cb7ee554-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"configurable_model.invoke(\n",
" \"what's your name\", config={\"configurable\": {\"model\": \"claude-3-5-sonnet-20240620\"}}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7f3b3d4a-4066-45e4-8297-ea81ac8e70b7",
"metadata": {},
"source": [
"### Configurable model with default values\n",
"\n",
"We can create a configurable model with default model values, specify which parameters are configurable, and add prefixes to configurable params:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "814a2289-d0db-401e-b555-d5116112b413",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?\", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_ce0793330f', 'finish_reason': 'stop', 'logprobs': None}, id='run-3923e328-7715-4cd6-b215-98e4b6bf7c9d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"first_llm = init_chat_model(\n",
" model=\"gpt-4o\",\n",
" temperature=0,\n",
" configurable_fields=(\"model\", \"model_provider\", \"temperature\", \"max_tokens\"),\n",
" config_prefix=\"first\", # useful when you have a chain with multiple models\n",
")\n",
"\n",
"first_llm.invoke(\"what's your name\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6c8755ba-c001-4f5a-a497-be3f1db83244",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"My name is Claude. It's nice to meet you!\", response_metadata={'id': 'msg_01RyYR64DoMPNCfHeNnroMXm', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-22446159-3723-43e6-88df-b84797e7751d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"first_llm.invoke(\n",
" \"what's your name\",\n",
" config={\n",
" \"configurable\": {\n",
" \"first_model\": \"claude-3-5-sonnet-20240620\",\n",
" \"first_temperature\": 0.5,\n",
" \"first_max_tokens\": 100,\n",
" }\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0072b1a3-7e44-4b4e-8b07-efe1ba91a689",
"metadata": {},
"source": [
"### Using a configurable model declaratively\n",
"\n",
"We can call declarative operations like `bind_tools`, `with_structured_output`, `with_configurable`, etc. on a configurable model and chain a configurable model in the same way that we would a regularly instantiated chat model object."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "067dabee-1050-4110-ae24-c48eba01e13b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'GetPopulation',\n",
" 'args': {'location': 'Los Angeles, CA'},\n",
" 'id': 'call_sYT3PFMufHGWJD32Hi2CTNUP'},\n",
" {'name': 'GetPopulation',\n",
" 'args': {'location': 'New York, NY'},\n",
" 'id': 'call_j1qjhxRnD3ffQmRyqjlI1Lnk'}]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"class GetWeather(BaseModel):\n",
" \"\"\"Get the current weather in a given location\"\"\"\n",
"\n",
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
"\n",
"\n",
"class GetPopulation(BaseModel):\n",
" \"\"\"Get the current population in a given location\"\"\"\n",
"\n",
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
"\n",
"\n",
"llm = init_chat_model(temperature=0)\n",
"llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])\n",
"\n",
"llm_with_tools.invoke(\n",
" \"what's bigger in 2024 LA or NYC\", config={\"configurable\": {\"model\": \"gpt-4o\"}}\n",
").tool_calls"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e57dfe9f-cd24-4e37-9ce9-ccf8daf78f89",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'GetPopulation',\n",
" 'args': {'location': 'Los Angeles, CA'},\n",
" 'id': 'toolu_01CxEHxKtVbLBrvzFS7GQ5xR'},\n",
" {'name': 'GetPopulation',\n",
" 'args': {'location': 'New York City, NY'},\n",
" 'id': 'toolu_013A79qt5toWSsKunFBDZd5S'}]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_with_tools.invoke(\n",
" \"what's bigger in 2024 LA or NYC\",\n",
" config={\"configurable\": {\"model\": \"claude-3-5-sonnet-20240620\"}},\n",
").tool_calls"
]
}
],
"metadata": {
@@ -149,7 +332,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.11.9"
}
},
"nbformat": 4,

View File

@@ -48,20 +48,10 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "40ed76a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: You are using pip version 22.0.4; however, version 24.0 is available.\n",
"You should consider upgrading via the '/Users/jacoblee/.pyenv/versions/3.10.5/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
]
}
],
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-openai\n",
"\n",

View File

@@ -180,7 +180,7 @@
"id": "32b1a992-8997-4c98-8eb2-c9fe9431b799",
"metadata": {},
"source": [
"Alternatively, we can add typing information via [Runnable.with_types](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html#langchain_core.runnables.base.Runnable.with_types):"
"Alternatively, the schema can be fully specified by directly passing the desired [args_schema](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html#langchain_core.tools.BaseTool.args_schema) for the tool:"
]
},
{
@@ -190,10 +190,18 @@
"metadata": {},
"outputs": [],
"source": [
"as_tool = runnable.with_types(input_type=Args).as_tool(\n",
" name=\"My tool\",\n",
" description=\"Explanation of when to use tool.\",\n",
")"
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"class GSchema(BaseModel):\n",
" \"\"\"Apply a function to an integer and list of integers.\"\"\"\n",
"\n",
" a: int = Field(..., description=\"Integer\")\n",
" b: List[int] = Field(..., description=\"List of ints\")\n",
"\n",
"\n",
"runnable = RunnableLambda(g)\n",
"as_tool = runnable.as_tool(GSchema)"
]
},
{

View File

@@ -131,7 +131,7 @@
"source": [
"## Base Chat Model\n",
"\n",
"Let's implement a chat model that echoes back the first `n` characetrs of the last message in the prompt!\n",
"Let's implement a chat model that echoes back the first `n` characters of the last message in the prompt!\n",
"\n",
"To do so, we will inherit from `BaseChatModel` and we'll need to implement the following:\n",
"\n",

View File

@@ -16,13 +16,15 @@
"| args_schema | Pydantic BaseModel | Optional but recommended, can be used to provide more information (e.g., few-shot examples) or validation for expected parameters |\n",
"| return_direct | boolean | Only relevant for agents. When True, after invoking the given tool, the agent will stop and return the result direcly to the user. |\n",
"\n",
"LangChain provides 3 ways to create tools:\n",
"LangChain supports the creation of tools from:\n",
"\n",
"1. Using [@tool decorator](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html#langchain_core.tools.tool) -- the simplest way to define a custom tool.\n",
"2. Using [StructuredTool.from_function](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.StructuredTool.html#langchain_core.tools.StructuredTool.from_function) class method -- this is similar to the `@tool` decorator, but allows more configuration and specification of both sync and async implementations.\n",
"1. Functions;\n",
"2. LangChain [Runnables](/docs/concepts#runnable-interface);\n",
"3. By sub-classing from [BaseTool](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.BaseTool.html) -- This is the most flexible method, it provides the largest degree of control, at the expense of more effort and code.\n",
"\n",
"The `@tool` or the `StructuredTool.from_function` class method should be sufficient for most use cases.\n",
"Creating tools from functions may be sufficient for most use cases, and can be done via a simple [@tool decorator](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html#langchain_core.tools.tool). If more configuration is needed-- e.g., specification of both sync and async implementations-- one can also use the [StructuredTool.from_function](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.StructuredTool.html#langchain_core.tools.StructuredTool.from_function) class method.\n",
"\n",
"In this guide we provide an overview of these methods.\n",
"\n",
":::{.callout-tip}\n",
"\n",
@@ -35,7 +37,9 @@
"id": "c7326b23",
"metadata": {},
"source": [
"## @tool decorator\n",
"## Creating tools from functions\n",
"\n",
"### @tool decorator\n",
"\n",
"This `@tool` decorator is the simplest way to define a custom tool. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description - so a docstring MUST be provided. "
]
@@ -51,7 +55,7 @@
"output_type": "stream",
"text": [
"multiply\n",
"multiply(a: int, b: int) -> int - Multiply two numbers.\n",
"Multiply two numbers.\n",
"{'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}}\n"
]
}
@@ -96,6 +100,57 @@
" return a * b"
]
},
{
"cell_type": "markdown",
"id": "8f0edc51-c586-414c-8941-c8abe779943f",
"metadata": {},
"source": [
"Note that `@tool` supports parsing of annotations, nested schemas, and other features:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5626423f-053e-4a66-adca-1d794d835397",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'multiply_by_maxSchema',\n",
" 'description': 'Multiply a by the maximum of b.',\n",
" 'type': 'object',\n",
" 'properties': {'a': {'title': 'A',\n",
" 'description': 'scale factor',\n",
" 'type': 'string'},\n",
" 'b': {'title': 'B',\n",
" 'description': 'list of ints over which to take maximum',\n",
" 'type': 'array',\n",
" 'items': {'type': 'integer'}}},\n",
" 'required': ['a', 'b']}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import Annotated, List\n",
"\n",
"\n",
"@tool\n",
"def multiply_by_max(\n",
" a: Annotated[str, \"scale factor\"],\n",
" b: Annotated[List[int], \"list of ints over which to take maximum\"],\n",
") -> int:\n",
" \"\"\"Multiply a by the maximum of b.\"\"\"\n",
" return a * max(b)\n",
"\n",
"\n",
"multiply_by_max.args_schema.schema()"
]
},
{
"cell_type": "markdown",
"id": "98d6eee9",
@@ -106,7 +161,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "9216d03a-f6ea-4216-b7e1-0661823a4c0b",
"metadata": {},
"outputs": [
@@ -115,7 +170,7 @@
"output_type": "stream",
"text": [
"multiplication-tool\n",
"multiplication-tool(a: int, b: int) -> int - Multiply two numbers.\n",
"Multiply two numbers.\n",
"{'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}\n",
"True\n"
]
@@ -143,19 +198,84 @@
"print(multiply.return_direct)"
]
},
{
"cell_type": "markdown",
"id": "33a9e94d-0b60-48f3-a4c2-247dce096e66",
"metadata": {},
"source": [
"#### Docstring parsing"
]
},
{
"cell_type": "markdown",
"id": "6d0cb586-93d4-4ff1-9779-71df7853cb68",
"metadata": {},
"source": [
"`@tool` can optionally parse [Google Style docstrings](https://google.github.io/styleguide/pyguide.html#383-functions-and-methods) and associate the docstring components (such as arg descriptions) to the relevant parts of the tool schema. To toggle this behavior, specify `parse_docstring`:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "336f5538-956e-47d5-9bde-b732559f9e61",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'fooSchema',\n",
" 'description': 'The foo.',\n",
" 'type': 'object',\n",
" 'properties': {'bar': {'title': 'Bar',\n",
" 'description': 'The bar.',\n",
" 'type': 'string'},\n",
" 'baz': {'title': 'Baz', 'description': 'The baz.', 'type': 'integer'}},\n",
" 'required': ['bar', 'baz']}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@tool(parse_docstring=True)\n",
"def foo(bar: str, baz: int) -> str:\n",
" \"\"\"The foo.\n",
"\n",
" Args:\n",
" bar: The bar.\n",
" baz: The baz.\n",
" \"\"\"\n",
" return bar\n",
"\n",
"\n",
"foo.args_schema.schema()"
]
},
{
"cell_type": "markdown",
"id": "f18a2503-5393-421b-99fa-4a01dd824d0e",
"metadata": {},
"source": [
":::{.callout-caution}\n",
"By default, `@tool(parse_docstring=True)` will raise `ValueError` if the docstring does not parse correctly. See [API Reference](https://api.python.langchain.com/en/latest/tools/langchain_core.tools.tool.html) for detail and examples.\n",
":::"
]
},
{
"cell_type": "markdown",
"id": "b63fcc3b",
"metadata": {},
"source": [
"## StructuredTool\n",
"### StructuredTool\n",
"\n",
"The `StrurcturedTool.from_function` class method provides a bit more configurability than the `@tool` decorator, without requiring much additional code."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "564fbe6f-11df-402d-b135-ef6ff25e1e63",
"metadata": {},
"outputs": [
@@ -198,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "6bc055d4-1fbe-4db5-8881-9c382eba6b1b",
"metadata": {},
"outputs": [
@@ -208,7 +328,7 @@
"text": [
"6\n",
"Calculator\n",
"Calculator(a: int, b: int) -> int - multiply numbers\n",
"multiply numbers\n",
"{'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}\n"
]
}
@@ -239,6 +359,63 @@
"print(calculator.args)"
]
},
{
"cell_type": "markdown",
"id": "5517995d-54e3-449b-8fdb-03561f5e4647",
"metadata": {},
"source": [
"## Creating tools from Runnables\n",
"\n",
"LangChain [Runnables](/docs/concepts#runnable-interface) that accept string or `dict` input can be converted to tools using the [as_tool](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html#langchain_core.runnables.base.Runnable.as_tool) method, which allows for the specification of names, descriptions, and additional schema information for arguments.\n",
"\n",
"Example usage:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8ef593c5-cf72-4c10-bfc9-7d21874a0c24",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'answer_style': {'title': 'Answer Style', 'type': 'string'}}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.language_models import GenericFakeChatModel\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"human\", \"Hello. Please respond in the style of {answer_style}.\")]\n",
")\n",
"\n",
"# Placeholder LLM\n",
"llm = GenericFakeChatModel(messages=iter([\"hello matey\"]))\n",
"\n",
"chain = prompt | llm | StrOutputParser()\n",
"\n",
"as_tool = chain.as_tool(\n",
" name=\"Style responder\", description=\"Description of when to use tool.\"\n",
")\n",
"as_tool.args"
]
},
{
"cell_type": "markdown",
"id": "0521b787-a146-45a6-8ace-ae1ac4669dd7",
"metadata": {},
"source": [
"See [this guide](/docs/how_to/convert_runnable_to_tool) for more detail."
]
},
{
"cell_type": "markdown",
"id": "b840074b-9c10-4ca0-aed8-626c52b2398f",
@@ -251,7 +428,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 10,
"id": "1dad8f8e",
"metadata": {},
"outputs": [],
@@ -300,7 +477,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"id": "bb551c33",
"metadata": {},
"outputs": [
@@ -351,7 +528,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"id": "6615cb77-fd4c-4676-8965-f92cc71d4944",
"metadata": {},
"outputs": [
@@ -383,7 +560,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 13,
"id": "bb2af583-eadd-41f4-a645-bf8748bd3dcd",
"metadata": {},
"outputs": [
@@ -428,7 +605,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 14,
"id": "4ad0932c-8610-4278-8c57-f9218f654c8a",
"metadata": {},
"outputs": [
@@ -473,7 +650,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 15,
"id": "7094c0e8-6192-4870-a942-aad5b5ae48fd",
"metadata": {},
"outputs": [],
@@ -496,7 +673,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 16,
"id": "b4d22022-b105-4ccc-a15b-412cb9ea3097",
"metadata": {},
"outputs": [
@@ -506,7 +683,7 @@
"'Error: There is no city by the name of foobar.'"
]
},
"execution_count": 12,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -530,7 +707,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 17,
"id": "3fad1728-d367-4e1b-9b54-3172981271cf",
"metadata": {},
"outputs": [
@@ -540,7 +717,7 @@
"\"There is no such city, but it's probably above 0K there!\""
]
},
"execution_count": 13,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@@ -564,7 +741,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 18,
"id": "ebfe7c1f-318d-4e58-99e1-f31e69473c46",
"metadata": {},
"outputs": [
@@ -574,7 +751,7 @@
"'The following errors occurred during tool execution: `Error: There is no city by the name of foobar.`'"
]
},
"execution_count": 14,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@@ -591,13 +768,189 @@
"\n",
"get_weather_tool.invoke({\"city\": \"foobar\"})"
]
},
{
"cell_type": "markdown",
"id": "1a8d8383-11b3-445e-956f-df4e96995e00",
"metadata": {},
"source": [
"## Returning artifacts of Tool execution\n",
"\n",
"Sometimes there are artifacts of a tool's execution that we want to make accessible to downstream components in our chain or agent, but that we don't want to expose to the model itself. For example if a tool returns custom objects like Documents, we may want to pass some view or metadata about this output to the model without passing the raw output to the model. At the same time, we may want to be able to access this full output elsewhere, for example in downstream tools.\n",
"\n",
"The Tool and [ToolMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolMessage.html) interfaces make it possible to distinguish between the parts of the tool output meant for the model (this is the ToolMessage.content) and those parts which are meant for use outside the model (ToolMessage.artifact).\n",
"\n",
":::info Requires ``langchain-core >= 0.2.19``\n",
"\n",
"This functionality was added in ``langchain-core == 0.2.19``. Please make sure your package is up to date.\n",
"\n",
":::\n",
"\n",
"If we want our tool to distinguish between message content and other artifacts, we need to specify `response_format=\"content_and_artifact\"` when defining our tool and make sure that we return a tuple of (content, artifact):"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "14905425-0334-43a0-9de9-5bcf622ede0e",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from typing import List, Tuple\n",
"\n",
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool(response_format=\"content_and_artifact\")\n",
"def generate_random_ints(min: int, max: int, size: int) -> Tuple[str, List[int]]:\n",
" \"\"\"Generate size random ints in the range [min, max].\"\"\"\n",
" array = [random.randint(min, max) for _ in range(size)]\n",
" content = f\"Successfully generated array of {size} random ints in [{min}, {max}].\"\n",
" return content, array"
]
},
{
"cell_type": "markdown",
"id": "49f057a6-8938-43ea-8faf-ae41e797ceb8",
"metadata": {},
"source": [
"If we invoke our tool directly with the tool arguments, we'll get back just the content part of the output:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "0f2e1528-404b-46e6-b87c-f0957c4b9217",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Successfully generated array of 10 random ints in [0, 9].'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_random_ints.invoke({\"min\": 0, \"max\": 9, \"size\": 10})"
]
},
{
"cell_type": "markdown",
"id": "1e62ebba-1737-4b97-b61a-7313ade4e8c2",
"metadata": {},
"source": [
"If we invoke our tool with a ToolCall (like the ones generated by tool-calling models), we'll get back a ToolMessage that contains both the content and artifact generated by the Tool:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cc197777-26eb-46b3-a83b-c2ce116c6311",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ToolMessage(content='Successfully generated array of 10 random ints in [0, 9].', name='generate_random_ints', tool_call_id='123', artifact=[1, 4, 2, 5, 3, 9, 0, 4, 7, 7])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_random_ints.invoke(\n",
" {\n",
" \"name\": \"generate_random_ints\",\n",
" \"args\": {\"min\": 0, \"max\": 9, \"size\": 10},\n",
" \"id\": \"123\", # required\n",
" \"type\": \"tool_call\", # required\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "dfdc1040-bf25-4790-b4c3-59452db84e11",
"metadata": {},
"source": [
"We can do the same when subclassing BaseTool:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fe1a09d1-378b-4b91-bb5e-0697c3d7eb92",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.tools import BaseTool\n",
"\n",
"\n",
"class GenerateRandomFloats(BaseTool):\n",
" name: str = \"generate_random_floats\"\n",
" description: str = \"Generate size random floats in the range [min, max].\"\n",
" response_format: str = \"content_and_artifact\"\n",
"\n",
" ndigits: int = 2\n",
"\n",
" def _run(self, min: float, max: float, size: int) -> Tuple[str, List[float]]:\n",
" range_ = max - min\n",
" array = [\n",
" round(min + (range_ * random.random()), ndigits=self.ndigits)\n",
" for _ in range(size)\n",
" ]\n",
" content = f\"Generated {size} floats in [{min}, {max}], rounded to {self.ndigits} decimals.\"\n",
" return content, array\n",
"\n",
" # Optionally define an equivalent async method\n",
"\n",
" # async def _arun(self, min: float, max: float, size: int) -> Tuple[str, List[float]]:\n",
" # ..."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8c3d16f6-1c4a-48ab-b05a-38547c592e79",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ToolMessage(content='Generated 3 floats in [0.1, 3.3333], rounded to 4 decimals.', name='generate_random_floats', tool_call_id='123', artifact=[1.4277, 0.7578, 2.4871])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rand_gen = GenerateRandomFloats(ndigits=4)\n",
"\n",
"rand_gen.invoke(\n",
" {\n",
" \"name\": \"generate_random_floats\",\n",
" \"args\": {\"min\": 0.1, \"max\": 3.3333, \"size\": 3},\n",
" \"id\": \"123\",\n",
" \"type\": \"tool_call\",\n",
" }\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "poetry-venv-311",
"language": "python",
"name": "python3"
"name": "poetry-venv-311"
},
"language_info": {
"codemirror_mode": {
@@ -609,7 +962,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.9"
},
"vscode": {
"interpreter": {

View File

@@ -67,15 +67,16 @@ If you'd prefer not to set an environment variable you can pass the key in direc
```python
from langchain_cohere import CohereEmbeddings
embeddings_model = CohereEmbeddings(cohere_api_key="...")
embeddings_model = CohereEmbeddings(cohere_api_key="...", model='embed-english-v3.0')
```
Otherwise you can initialize without any params:
Otherwise you can initialize simply as shown below:
```python
from langchain_cohere import CohereEmbeddings
embeddings_model = CohereEmbeddings()
embeddings_model = CohereEmbeddings(model='embed-english-v3.0')
```
Do note that it is mandatory to pass the model parameter while initializing the CohereEmbeddings class.
</TabItem>
<TabItem value="huggingface" label="Hugging Face">

View File

@@ -84,7 +84,7 @@ These are the core building blocks you can use when building applications.
- [How to: use chat model to call tools](/docs/how_to/tool_calling)
- [How to: stream tool calls](/docs/how_to/tool_streaming)
- [How to: few shot prompt tool behavior](/docs/how_to/tools_few_shot)
- [How to: bind model-specific formated tools](/docs/how_to/tools_model_specific)
- [How to: bind model-specific formatted tools](/docs/how_to/tools_model_specific)
- [How to: force a specific tool call](/docs/how_to/tool_choice)
- [How to: init any model in one line](/docs/how_to/chat_models_universal_init/)
@@ -195,7 +195,9 @@ LangChain [Tools](/docs/concepts/#tools) contain a description of the tool (to p
- [How to: add a human in the loop to tool usage](/docs/how_to/tools_human)
- [How to: handle errors when calling tools](/docs/how_to/tools_error)
- [How to: disable parallel tool calling](/docs/how_to/tool_choice)
- [How to: stream events from within a tool](/docs/how_to/tool_stream_events)
- [How to: access the `RunnableConfig` object within a custom tool](/docs/how_to/tool_configure)
- [How to: stream events from child runs within a custom tool](/docs/how_to/tool_stream_events)
- [How to: return extra artifacts from a tool](/docs/how_to/tool_artifacts/)
### Multimodal

View File

@@ -63,6 +63,38 @@
"Notice that if the contents of one of the messages to merge is a list of content blocks then the merged message will have a list of content blocks. And if both messages to merge have string contents then those are concatenated with a newline character."
]
},
{
"cell_type": "markdown",
"id": "11f7e8d3",
"metadata": {},
"source": [
"The `merge_message_runs` utility also works with messages composed together using the overloaded `+` operation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b51855c5",
"metadata": {},
"outputs": [],
"source": [
"messages = (\n",
" SystemMessage(\"you're a good assistant.\")\n",
" + SystemMessage(\"you always respond with a joke.\")\n",
" + HumanMessage([{\"type\": \"text\", \"text\": \"i wonder why it's called langchain\"}])\n",
" + HumanMessage(\"and who is harrison chasing anyways\")\n",
" + AIMessage(\n",
" 'Well, I guess they thought \"WordRope\" and \"SentenceString\" just didn\\'t have the same ring to it!'\n",
" )\n",
" + AIMessage(\n",
" \"Why, he's probably chasing after the last cup of coffee in the office!\"\n",
" )\n",
")\n",
"\n",
"merged = merge_message_runs(messages)\n",
"print(\"\\n\\n\".join([repr(x) for x in merged]))"
]
},
{
"cell_type": "markdown",
"id": "1b2eee74-71c8-4168-b968-bca580c25d18",

View File

@@ -0,0 +1,395 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "503e36ae-ca62-4f8a-880c-4fe78ff5df93",
"metadata": {},
"source": [
"# How to return extra artifacts from a tool\n",
"\n",
":::info Prerequisites\n",
"This guide assumes familiarity with the following concepts:\n",
"\n",
"- [Tools](/docs/concepts/#tools)\n",
"- [Function/tool calling](/docs/concepts/#functiontool-calling)\n",
"\n",
":::\n",
"\n",
"Tools are utilities that can be called by a model, and whose outputs are designed to be fed back to a model. Sometimes, however, there are artifacts of a tool's execution that we want to make accessible to downstream components in our chain or agent, but that we don't want to expose to the model itself. For example if a tool returns a custom object, a dataframe or an image, we may want to pass some metadata about this output to the model without passing the actual output to the model. At the same time, we may want to be able to access this full output elsewhere, for example in downstream tools.\n",
"\n",
"The Tool and [ToolMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolMessage.html) interfaces make it possible to distinguish between the parts of the tool output meant for the model (this is the ToolMessage.content) and those parts which are meant for use outside the model (ToolMessage.artifact).\n",
"\n",
":::info Requires ``langchain-core >= 0.2.19``\n",
"\n",
"This functionality was added in ``langchain-core == 0.2.19``. Please make sure your package is up to date.\n",
"\n",
":::\n",
"\n",
"## Defining the tool\n",
"\n",
"If we want our tool to distinguish between message content and other artifacts, we need to specify `response_format=\"content_and_artifact\"` when defining our tool and make sure that we return a tuple of (content, artifact):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "762b9199-885f-4946-9c98-cc54d72b0d76",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU \"langchain-core>=0.2.19\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b9eb179d-1f41-4748-9866-b3d3e8c73cd0",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from typing import List, Tuple\n",
"\n",
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool(response_format=\"content_and_artifact\")\n",
"def generate_random_ints(min: int, max: int, size: int) -> Tuple[str, List[int]]:\n",
" \"\"\"Generate size random ints in the range [min, max].\"\"\"\n",
" array = [random.randint(min, max) for _ in range(size)]\n",
" content = f\"Successfully generated array of {size} random ints in [{min}, {max}].\"\n",
" return content, array"
]
},
{
"cell_type": "markdown",
"id": "0ab05d25-af4a-4e5a-afe2-f090416d7ee7",
"metadata": {},
"source": [
"## Invoking the tool with ToolCall\n",
"\n",
"If we directly invoke our tool with just the tool arguments, you'll notice that we only get back the content part of the Tool output:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5e7d5e77-3102-4a59-8ade-e4e699dd1817",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Successfully generated array of 10 random ints in [0, 9].'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_random_ints.invoke({\"min\": 0, \"max\": 9, \"size\": 10})"
]
},
{
"cell_type": "markdown",
"id": "30db7228-f04c-489e-afda-9a572eaa90a1",
"metadata": {},
"source": [
"In order to get back both the content and the artifact, we need to invoke our model with a ToolCall (which is just a dictionary with \"name\", \"args\", \"id\" and \"type\" keys), which has additional info needed to generate a ToolMessage like the tool call ID:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "da1d939d-a900-4b01-92aa-d19011a6b034",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ToolMessage(content='Successfully generated array of 10 random ints in [0, 9].', name='generate_random_ints', tool_call_id='123', artifact=[2, 8, 0, 6, 0, 0, 1, 5, 0, 0])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_random_ints.invoke(\n",
" {\n",
" \"name\": \"generate_random_ints\",\n",
" \"args\": {\"min\": 0, \"max\": 9, \"size\": 10},\n",
" \"id\": \"123\", # required\n",
" \"type\": \"tool_call\", # required\n",
" }\n",
")"
]
},
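{
"cell_type": "markdown",
"id": "artifact-access-note",
"metadata": {},
"source": [
"As an illustrative aside (a minimal sketch, not part of the original guide), the returned `ToolMessage` exposes both parts separately: `content` is the string intended for the model, while `artifact` holds the full Python object for downstream use. The `tool_msg` variable below is just a hypothetical name for this example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "artifact-access-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: both halves of the tool output are available on the ToolMessage.\n",
"tool_msg = generate_random_ints.invoke(\n",
" {\n",
" \"name\": \"generate_random_ints\",\n",
" \"args\": {\"min\": 0, \"max\": 9, \"size\": 10},\n",
" \"id\": \"123\",\n",
" \"type\": \"tool_call\",\n",
" }\n",
")\n",
"print(tool_msg.content)  # short summary string sent to the model\n",
"print(tool_msg.artifact)  # full list of ints, kept out of the model context"
]
},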
{
"cell_type": "markdown",
"id": "a3cfc03d-020b-42c7-b0f8-c824af19e45e",
"metadata": {},
"source": [
"## Using with a model\n",
"\n",
"With a [tool-calling model](/docs/how_to/tool_calling/), we can easily use a model to call our Tool and generate ToolMessages:\n",
"\n",
"```{=mdx}\n",
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
"\n",
"<ChatModelTabs\n",
" customVarName=\"llm\"\n",
"/>\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "74de0286-b003-4b48-9cdd-ecab435515ca",
"metadata": {},
"outputs": [],
"source": [
"# | echo: false\n",
"# | output: false\n",
"\n",
"from langchain_anthropic import ChatAnthropic\n",
"\n",
"llm = ChatAnthropic(model=\"claude-3-5-sonnet-20240620\", temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8a67424b-d19c-43df-ac7b-690bca42146c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'generate_random_ints',\n",
" 'args': {'min': 1, 'max': 24, 'size': 6},\n",
" 'id': 'toolu_01EtALY3Wz1DVYhv1TLvZGvE',\n",
" 'type': 'tool_call'}]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_with_tools = llm.bind_tools([generate_random_ints])\n",
"\n",
"ai_msg = llm_with_tools.invoke(\"generate 6 positive ints less than 25\")\n",
"ai_msg.tool_calls"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "00c4e906-3ca8-41e8-a0be-65cb0db7d574",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ToolMessage(content='Successfully generated array of 6 random ints in [1, 24].', name='generate_random_ints', tool_call_id='toolu_01EtALY3Wz1DVYhv1TLvZGvE', artifact=[2, 20, 23, 8, 1, 15])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_random_ints.invoke(ai_msg.tool_calls[0])"
]
},
{
"cell_type": "markdown",
"id": "ddef2690-70de-4542-ab20-2337f77f3e46",
"metadata": {},
"source": [
"If we just pass in the tool call args, we'll only get back the content:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f4a6c9a6-0ffc-4b0e-a59f-f3c3d69d824d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Successfully generated array of 6 random ints in [1, 24].'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_random_ints.invoke(ai_msg.tool_calls[0][\"args\"])"
]
},
{
"cell_type": "markdown",
"id": "98d6443b-ff41-4d91-8523-b6274fc74ee5",
"metadata": {},
"source": [
"If we wanted to declaratively create a chain, we could do this:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "eb55ec23-95a4-464e-b886-d9679bf3aaa2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[ToolMessage(content='Successfully generated array of 1 random ints in [1, 5].', name='generate_random_ints', tool_call_id='toolu_01FwYhnkwDPJPbKdGq4ng6uD', artifact=[5])]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from operator import attrgetter\n",
"\n",
"chain = llm_with_tools | attrgetter(\"tool_calls\") | generate_random_ints.map()\n",
"\n",
"chain.invoke(\"give me a random number between 1 and 5\")"
]
},
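{
"cell_type": "markdown",
"id": "chain-artifact-note",
"metadata": {},
"source": [
"As a brief aside (a minimal sketch under the same assumptions as the chain above), the artifacts remain attached to the `ToolMessage`s the chain returns, so downstream code can pull them back out without re-running the tool:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "chain-artifact-code",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: collect the raw artifacts from the ToolMessages produced by the chain.\n",
"tool_messages = chain.invoke(\"give me a random number between 1 and 5\")\n",
"[msg.artifact for msg in tool_messages]"
]
},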
{
"cell_type": "markdown",
"id": "4df46be2-babb-4bfe-a641-91cd3d03ffaf",
"metadata": {},
"source": [
"## Creating from BaseTool class\n",
"\n",
"If you want to create a BaseTool object directly, instead of decorating a function with `@tool`, you can do so like this:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9a9129e1-6aee-4a10-ad57-62ef3bf0276c",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.tools import BaseTool\n",
"\n",
"\n",
"class GenerateRandomFloats(BaseTool):\n",
" name: str = \"generate_random_floats\"\n",
" description: str = \"Generate size random floats in the range [min, max].\"\n",
" response_format: str = \"content_and_artifact\"\n",
"\n",
" ndigits: int = 2\n",
"\n",
" def _run(self, min: float, max: float, size: int) -> Tuple[str, List[float]]:\n",
" range_ = max - min\n",
" array = [\n",
" round(min + (range_ * random.random()), ndigits=self.ndigits)\n",
" for _ in range(size)\n",
" ]\n",
" content = f\"Generated {size} floats in [{min}, {max}], rounded to {self.ndigits} decimals.\"\n",
" return content, array\n",
"\n",
" # Optionally define an equivalent async method\n",
"\n",
" # async def _arun(self, min: float, max: float, size: int) -> Tuple[str, List[float]]:\n",
" # ..."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d7322619-f420-4b29-8ee5-023e693d0179",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Generated 3 floats in [0.1, 3.3333], rounded to 4 decimals.'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rand_gen = GenerateRandomFloats(ndigits=4)\n",
"rand_gen.invoke({\"min\": 0.1, \"max\": 3.3333, \"size\": 3})"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0892f277-23a6-4bb8-a0e9-59f533ac9750",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ToolMessage(content='Generated 3 floats in [0.1, 3.3333], rounded to 4 decimals.', name='generate_random_floats', tool_call_id='123', artifact=[1.5789, 2.464, 2.2719])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rand_gen.invoke(\n",
" {\n",
" \"name\": \"generate_random_floats\",\n",
" \"args\": {\"min\": 0.1, \"max\": 3.3333, \"size\": 3},\n",
" \"id\": \"123\",\n",
" \"type\": \"tool_call\",\n",
" }\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-311",
"language": "python",
"name": "poetry-venv-311"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,132 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to access the RunnableConfig object within a custom tool\n",
"\n",
":::info Prerequisites\n",
"\n",
"This guide assumes familiarity with the following concepts:\n",
"\n",
"- [LangChain Tools](/docs/concepts/#tools)\n",
"- [Custom tools](/docs/how_to/custom_tools)\n",
"- [LangChain Expression Language (LCEL)](/docs/concepts/#langchain-expression-language-lcel)\n",
"- [Configuring runnable behavior](/docs/how_to/configure/)\n",
"\n",
":::\n",
"\n",
"If you have a tool that call chat models, retrievers, or other runnables, you may want to access internal events from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
"\n",
"Tools are runnables, and you can treat them the same way as any other runnable at the interface level - you can call `invoke()`, `batch()`, and `stream()` on them as normal. However, when writing custom tools, you may want to invoke other runnables like chat models or retrievers. In order to properly trace and configure those sub-invocations, you'll need to manually access and pass in the tool's current [`RunnableConfig`](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.config.RunnableConfig.html) object. This guide show you some examples of how to do that.\n",
"\n",
":::caution Compatibility\n",
"\n",
"This guide requires `langchain-core>=0.2.16`.\n",
"\n",
":::\n",
"\n",
"## Inferring by parameter type\n",
"\n",
"To access reference the active config object from your custom tool, you'll need to add a parameter to your tool's signature typed as `RunnableConfig`. When you invoke your tool, LangChain will inspect your tool's signature, look for a parameter typed as `RunnableConfig`, and if it exists, populate that parameter with the correct value.\n",
"\n",
"**Note:** The actual name of the parameter doesn't matter, only the typing.\n",
"\n",
"To illustrate this, define a custom tool that takes a two parameters - one typed as a string, the other typed as `RunnableConfig`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain_core"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.runnables import RunnableConfig\n",
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"async def reverse_tool(text: str, special_config_param: RunnableConfig) -> str:\n",
" \"\"\"A test tool that combines input text with a configurable parameter.\"\"\"\n",
" return (text + special_config_param[\"configurable\"][\"additional_field\"])[::-1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, if we invoke the tool with a `config` containing a `configurable` field, we can see that `additional_field` is passed through correctly:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'321cba'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await reverse_tool.ainvoke(\n",
" {\"text\": \"abc\"}, config={\"configurable\": {\"additional_field\": \"123\"}}\n",
")"
]
},
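{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustrative aside (a hedged sketch, not part of the original guide), the populated `RunnableConfig` is a plain dictionary at runtime, so a tool can read other fields such as `tags` alongside `configurable`. The `describe_config` tool below is hypothetical and exists only for this example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: a hypothetical tool that reads several fields from the injected config.\n",
"@tool\n",
"async def describe_config(text: str, config: RunnableConfig) -> str:\n",
" \"\"\"Echo the input text along with selected config fields.\"\"\"\n",
" tags = config.get(\"tags\", [])\n",
" extra = config.get(\"configurable\", {}).get(\"additional_field\", \"\")\n",
" return f\"{text}-{extra} (tags={tags})\"\n",
"\n",
"\n",
"await describe_config.ainvoke(\n",
" {\"text\": \"abc\"},\n",
" config={\"configurable\": {\"additional_field\": \"123\"}, \"tags\": [\"demo\"]},\n",
")"
]
},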
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"You've now seen how to configure and stream events from within a tool. Next, check out the following guides for more on using tools:\n",
"\n",
"- [Stream events from child runs within a custom tool](/docs/how_to/tool_stream_events/)\n",
"- Pass [tool results back to a model](/docs/how_to/tool_results_pass_to_model)\n",
"\n",
"You can also check out some more specific uses of tool calling:\n",
"\n",
"- Building [tool-using chains and agents](/docs/how_to#tools)\n",
"- Getting [structured outputs](/docs/how_to/structured_output/) from models"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -6,12 +6,20 @@
"source": [
"# How to pass tool outputs to the model\n",
"\n",
"If we're using the model-generated tool invocations to actually call tools and want to pass the tool results back to the model, we can do so using `ToolMessage`s. First, let's define our tools and our model."
":::info Prerequisites\n",
"This guide assumes familiarity with the following concepts:\n",
"\n",
"- [Tools](/docs/concepts/#tools)\n",
"- [Function/tool calling](/docs/concepts/#functiontool-calling)\n",
"\n",
":::\n",
"\n",
"If we're using the model-generated tool invocations to actually call tools and want to pass the tool results back to the model, we can do so using `ToolMessage`s and `ToolCall`s. First, let's define our tools and our model."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -35,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -54,25 +62,32 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use ``ToolMessage`` to pass back the output of the tool calls to the model."
"The nice thing about Tools is that if we invoke them with a ToolCall, we'll automatically get back a ToolMessage that can be fed back to the model: \n",
"\n",
":::info Requires ``langchain-core >= 0.2.19``\n",
"\n",
"This functionality was added in ``langchain-core == 0.2.19``. Please make sure your package is up to date.\n",
"\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='What is 3 * 12? Also, what is 11 + 49?'),\n",
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_svc2GLSxNFALbaCAbSjMI9J8', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'Multiply'}, 'type': 'function'}, {'id': 'call_r8jxte3zW6h3MEGV3zH2qzFh', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'Add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 50, 'prompt_tokens': 105, 'total_tokens': 155}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_d9767fc5b9', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-a79ad1dd-95f1-4a46-b688-4c83f327a7b3-0', tool_calls=[{'name': 'Multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_svc2GLSxNFALbaCAbSjMI9J8'}, {'name': 'Add', 'args': {'a': 11, 'b': 49}, 'id': 'call_r8jxte3zW6h3MEGV3zH2qzFh'}]),\n",
" ToolMessage(content='36', tool_call_id='call_svc2GLSxNFALbaCAbSjMI9J8'),\n",
" ToolMessage(content='60', tool_call_id='call_r8jxte3zW6h3MEGV3zH2qzFh')]"
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Smg3NHJNxrKfAmd4f9GkaYn3', 'function': {'arguments': '{\"a\": 3, \"b\": 12}', 'name': 'multiply'}, 'type': 'function'}, {'id': 'call_55K1C0DmH6U5qh810gW34xZ0', 'function': {'arguments': '{\"a\": 11, \"b\": 49}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 49, 'prompt_tokens': 88, 'total_tokens': 137}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-56657feb-96dd-456c-ab8e-1857eab2ade0-0', tool_calls=[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_Smg3NHJNxrKfAmd4f9GkaYn3', 'type': 'tool_call'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_55K1C0DmH6U5qh810gW34xZ0', 'type': 'tool_call'}], usage_metadata={'input_tokens': 88, 'output_tokens': 49, 'total_tokens': 137}),\n",
" ToolMessage(content='36', name='multiply', tool_call_id='call_Smg3NHJNxrKfAmd4f9GkaYn3'),\n",
" ToolMessage(content='60', name='add', tool_call_id='call_55K1C0DmH6U5qh810gW34xZ0')]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
@@ -85,24 +100,25 @@
"messages.append(ai_msg)\n",
"for tool_call in ai_msg.tool_calls:\n",
" selected_tool = {\"add\": add, \"multiply\": multiply}[tool_call[\"name\"].lower()]\n",
" tool_output = selected_tool.invoke(tool_call[\"args\"])\n",
" messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
" tool_msg = selected_tool.invoke(tool_call)\n",
" messages.append(tool_msg)\n",
"messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='3 * 12 is 36 and 11 + 49 is 60.', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 171, 'total_tokens': 189}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_d9767fc5b9', 'finish_reason': 'stop', 'logprobs': None}, id='run-20b52149-e00d-48ea-97cf-f8de7a255f8c-0')"
"AIMessage(content='3 * 12 is 36 and 11 + 49 is 60.', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 153, 'total_tokens': 171}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-ba5032f0-f773-406d-a408-8314e66511d0-0', usage_metadata={'input_tokens': 153, 'output_tokens': 18, 'total_tokens': 171})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
@@ -118,10 +134,24 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv-311",
"language": "python",
"name": "poetry-venv-311"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

View File

@@ -28,17 +28,21 @@
"which shows how to create an agent that keeps track of a given user's favorite pets.\n",
":::\n",
"\n",
"There are times where tools need to use runtime values that should not be populated by the LLM. For example, the tool logic may require using the ID of the user who made the request. In this case, allowing the LLM to control the parameter is a security risk.\n",
"You may need to bind values to a tool that are only known at runtime. For example, the tool logic may require using the ID of the user who made the request.\n",
"\n",
"Instead, the LLM should only control the parameters of the tool that are meant to be controlled by the LLM, while other parameters (such as user ID) should be fixed by the application logic. These defined parameters should not be part of the tool's final schema.\n",
"Most of the time, such values should not be controlled by the LLM. In fact, allowing the LLM to control the user ID may lead to a security risk.\n",
"\n",
"This how-to guide shows some design patterns that create the tool dynamically at run time and binds appropriate values to them."
"Instead, the LLM should only control the parameters of the tool that are meant to be controlled by the LLM, while other parameters (such as user ID) should be fixed by the application logic.\n",
"\n",
"This how-to guide shows a simple design pattern that creates the tool dynamically at run time and binds to them appropriate values."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can bind them to chat models as follows:\n",
"\n",
"```{=mdx}\n",
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
"\n",
@@ -51,14 +55,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# | output: false\n",
"# | echo: false\n",
"\n",
"%pip install -qU langchain_core langchain_openai\n",
"%pip install -qU langchain langchain_openai\n",
"\n",
"import os\n",
"from getpass import getpass\n",
@@ -75,17 +90,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the `curry` utility function\n",
"# Passing request time information\n",
"\n",
":::caution Compatibility\n",
"\n",
"This function is only available in `langchain_core>=0.2.17`.\n",
"\n",
":::\n",
"\n",
"We can bind arguments to the tool's inner function via a utility wrapper. This will use a technique called [currying](https://en.wikipedia.org/wiki/Currying) to bind arguments to the function while also removing it from the function signature.\n",
"\n",
"Below, we initialize a tool that lists a user's favorite pet. It requires a `user_id` that we'll curry ahead of time."
"The idea is to create the tool dynamically at request time, and bind to it the appropriate information. For example,\n",
"this information may be the user ID as resolved from the request itself."
]
},
{
@@ -94,98 +102,18 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.tools import StructuredTool\n",
"from langchain_core.utils.curry import curry\n",
"from typing import List\n",
"\n",
"user_to_pets = {\"eugene\": [\"cats\"]}\n",
"\n",
"\n",
"def list_favorite_pets(user_id: str) -> None:\n",
" \"\"\"List favorite pets, if any.\"\"\"\n",
" return user_to_pets.get(user_id, [])\n",
"\n",
"\n",
"curried_function = curry(list_favorite_pets, user_id=\"eugene\")\n",
"\n",
"curried_tool = StructuredTool.from_function(curried_function)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we examine the schema of the curried tool, we can see that it no longer has `user_id` as part of its signature:"
"from langchain_core.output_parsers import JsonOutputParser\n",
"from langchain_core.tools import BaseTool, tool"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'list_favorite_petsSchema',\n",
" 'description': 'List favorite pets, if any.',\n",
" 'type': 'object',\n",
" 'properties': {}}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"curried_tool.input_schema.schema()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But if we invoke it, we can see that it returns Eugene's favorite pets, `cats`:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['cats']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"curried_tool.invoke({})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using scope\n",
"\n",
"We can achieve a similar result by wrapping the tool declarations themselves in a function. This lets us take advantage of the closure created by the wrapper to pass a variable into each tool. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"from langchain_core.tools import BaseTool, tool\n",
"\n",
"user_to_pets = {}\n",
"\n",
"\n",
@@ -205,7 +133,7 @@
"\n",
" @tool\n",
" def list_favorite_pets() -> None:\n",
" \"\"\"List favorite pets, if any.\"\"\"\n",
" \"\"\"List favorite pets if any.\"\"\"\n",
" return user_to_pets.get(user_id, [])\n",
"\n",
" return [update_favorite_pets, delete_favorite_pets, list_favorite_pets]"
@@ -215,12 +143,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Verify that the tools work correctly:"
"Verify that the tools work correctly"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -241,14 +169,21 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"\n",
"def handle_run_time_request(user_id: str, query: str):\n",
" \"\"\"Handle run time request.\"\"\"\n",
" tools = generate_tools_for_user(user_id)\n",
" llm_with_tools = llm.bind_tools(tools)\n",
" prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", \"You are a helpful assistant.\")],\n",
" )\n",
" chain = prompt | llm_with_tools\n",
" return llm_with_tools.invoke(query)"
]
},
@@ -261,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -269,10 +204,10 @@
"text/plain": [
"[{'name': 'update_favorite_pets',\n",
" 'args': {'pets': ['cats', 'parrots']},\n",
" 'id': 'call_c8agYHY1COFSAgwZR11OGCmQ'}]"
" 'id': 'call_jJvjPXsNbFO5MMgW0q84iqCN'}]"
]
},
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -313,7 +248,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@@ -4,25 +4,31 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to stream events from within a tool\n",
"# How to stream events from child runs within a custom tool\n",
"\n",
":::info Prerequisites\n",
"\n",
"This guide assumes familiarity with the following concepts:\n",
"- [LangChain Tools](/docs/concepts/#tools)\n",
"- [Custom tools](/docs/how_to/custom_tools)\n",
"- [Using stream events](/docs/how_to/streaming/#using-stream-events)\n",
"- [Accessing RunnableConfig within a custom tool](/docs/how_to/tool_configure/)\n",
"\n",
":::\n",
"\n",
"If you have tools that call LLMs, retrievers, or other runnables, you may want to access internal events from those runnables. This guide shows you a few ways you can do this using the `astream_events()` method.\n",
"If you have tools that call chat models, retrievers, or other runnables, you may want to access internal events from those runnables or configure them with additional properties. This guide shows you how to manually pass parameters properly so that you can do this using the `astream_events()` method.\n",
"\n",
":::caution Compatibility\n",
"\n",
"LangChain cannot automatically propagate configuration, including callbacks necessary for `astream_events()`, to child runnables if you are running `async` code in `python<=3.10`. This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
"\n",
"If you are running `python>=3.11`, configuration will automatically propagate to child runnables in async environments, and you don't need to access the `RunnableConfig` object for that tool as shown in this guide. However, it is still a good idea if your code may run in other Python versions.\n",
"\n",
"This guide also requires `langchain-core>=0.2.16`.\n",
"\n",
":::caution\n",
"LangChain cannot automatically propagate callbacks to child runnables if you are running async code in python<=3.10.\n",
" \n",
"This is a common reason why you may fail to see events being emitted from custom runnables or tools.\n",
":::\n",
"\n",
"We'll define a custom tool below that calls a chain that summarizes its input in a special way by prompting an LLM to return only 10 words, then reversing the output:\n",
"Say you have a custom tool that calls a chain that condenses its input by prompting a chat model to return only 10 words, then reversing the output. First, define it in a naive way:\n",
"\n",
"```{=mdx}\n",
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
@@ -40,7 +46,7 @@
"# | output: false\n",
"# | echo: false\n",
"\n",
"%pip install -qU langchain langchain_anthropic\n",
"%pip install -qU langchain langchain_anthropic langchain_core\n",
"\n",
"import os\n",
"from getpass import getpass\n",
@@ -65,7 +71,7 @@
"\n",
"\n",
"@tool\n",
"def special_summarization_tool(long_text: str) -> str:\n",
"async def special_summarization_tool(long_text: str) -> str:\n",
" \"\"\"A tool that summarizes input text using advanced techniques.\"\"\"\n",
" prompt = ChatPromptTemplate.from_template(\n",
" \"You are an expert writer. Summarize the following text in 10 words or less:\\n\\n{long_text}\"\n",
@@ -75,7 +81,7 @@
" return x[::-1]\n",
"\n",
" chain = prompt | model | StrOutputParser() | reverse\n",
" summary = chain.invoke({\"long_text\": long_text})\n",
" summary = await chain.ainvoke({\"long_text\": long_text})\n",
" return summary"
]
},
@@ -83,7 +89,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If you just invoke the tool directly, you can see that you only get the final response:"
"Invoking the tool directly works just fine:"
]
},
{
@@ -116,31 +122,90 @@
"Coming! Hang on a second.\n",
"\"\"\"\n",
"\n",
"special_summarization_tool.invoke({\"long_text\": LONG_TEXT})"
"await special_summarization_tool.ainvoke({\"long_text\": LONG_TEXT})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you wanted to access the raw output from the chat model, you could use the [`astream_events()`](/docs/how_to/streaming/#using-stream-events) method and look for `on_chat_model_end` events:"
"But if you wanted to access the raw output from the chat model rather than the full tool, you might try to use the [`astream_events()`](/docs/how_to/streaming/#using-stream-events) method and look for an `on_chat_model_end` event. Here's what happens:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"stream = special_summarization_tool.astream_events(\n",
" {\"long_text\": LONG_TEXT}, version=\"v2\"\n",
")\n",
"\n",
"async for event in stream:\n",
" if event[\"event\"] == \"on_chat_model_end\":\n",
" # Never triggers in python<=3.10!\n",
" print(event)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You'll notice (unless you're running through this guide in `python>=3.11`) that there are no chat model events emitted from the child run!\n",
"\n",
"This is because the example above does not pass the tool's config object into the internal chain. To fix this, redefine your tool to take a special parameter typed as `RunnableConfig` (see [this guide](/docs/how_to/tool_configure) for more details). You'll also need to pass that parameter through into the internal chain when executing it:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.runnables import RunnableConfig\n",
"\n",
"\n",
"@tool\n",
"async def special_summarization_tool_with_config(\n",
" long_text: str, config: RunnableConfig\n",
") -> str:\n",
" \"\"\"A tool that summarizes input text using advanced techniques.\"\"\"\n",
" prompt = ChatPromptTemplate.from_template(\n",
" \"You are an expert writer. Summarize the following text in 10 words or less:\\n\\n{long_text}\"\n",
" )\n",
"\n",
" def reverse(x: str):\n",
" return x[::-1]\n",
"\n",
" chain = prompt | model | StrOutputParser() | reverse\n",
" # Pass the \"config\" object as an argument to any executed runnables\n",
" summary = await chain.ainvoke({\"long_text\": long_text}, config=config)\n",
" return summary"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now try the same `astream_events()` call as before with your new tool:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chat_model_end', 'data': {'output': AIMessage(content='Bee defies physics; Barry chooses outfit for graduation day.', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-195c0986-2ffa-43a3-9366-f2f96c42fe57', usage_metadata={'input_tokens': 182, 'output_tokens': 16, 'total_tokens': 198}), 'input': {'messages': [[HumanMessage(content=\"You are an expert writer. Summarize the following text in 10 words or less:\\n\\n\\nNARRATOR:\\n(Black screen with text; The sound of buzzing bees can be heard)\\nAccording to all known laws of aviation, there is no way a bee should be able to fly. Its wings are too small to get its fat little body off the ground. The bee, of course, flies anyway because bees don't care what humans think is impossible.\\nBARRY BENSON:\\n(Barry is picking out a shirt)\\nYellow, black. Yellow, black. Yellow, black. Yellow, black. Ooh, black and yellow! Let's shake it up a little.\\nJANET BENSON:\\nBarry! Breakfast is ready!\\nBARRY:\\nComing! Hang on a second.\\n\")]]}}, 'run_id': '195c0986-2ffa-43a3-9366-f2f96c42fe57', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['370919df-1bc3-43ae-aab2-8e112a4ddf47', 'de535624-278b-4927-9393-6d0cac3248df']}\n"
"{'event': 'on_chat_model_end', 'data': {'output': AIMessage(content='Bee defies physics; Barry chooses outfit for graduation day.', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-d23abc80-0dce-4f74-9d7b-fb98ca4f2a9e', usage_metadata={'input_tokens': 182, 'output_tokens': 16, 'total_tokens': 198}), 'input': {'messages': [[HumanMessage(content=\"You are an expert writer. Summarize the following text in 10 words or less:\\n\\n\\nNARRATOR:\\n(Black screen with text; The sound of buzzing bees can be heard)\\nAccording to all known laws of aviation, there is no way a bee should be able to fly. Its wings are too small to get its fat little body off the ground. The bee, of course, flies anyway because bees don't care what humans think is impossible.\\nBARRY BENSON:\\n(Barry is picking out a shirt)\\nYellow, black. Yellow, black. Yellow, black. Yellow, black. Ooh, black and yellow! Let's shake it up a little.\\nJANET BENSON:\\nBarry! Breakfast is ready!\\nBARRY:\\nComing! Hang on a second.\\n\")]]}}, 'run_id': 'd23abc80-0dce-4f74-9d7b-fb98ca4f2a9e', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['f25c41fe-8972-4893-bc40-cecf3922c1fa']}\n"
]
}
],
"source": [
"stream = special_summarization_tool.astream_events(\n",
"stream = special_summarization_tool_with_config.astream_events(\n",
" {\"long_text\": LONG_TEXT}, version=\"v2\"\n",
")\n",
"\n",
@@ -153,38 +218,38 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"And you can see that you get the raw response from the chat model.\n",
"Awesome! This time there's an event emitted.\n",
"\n",
"`astream_events()` will automatically call internal runnables in a chain with streaming enabled if possible, so if you wanted to a stream of tokens as they are generated from the chat model, you could simply filter our calls to look for `on_chat_model_stream` events with no other changes:"
"For streaming, `astream_events()` automatically calls internal runnables in a chain with streaming enabled if possible, so if you wanted to a stream of tokens as they are generated from the chat model, you could simply filter to look for `on_chat_model_stream` events with no other changes:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', usage_metadata={'input_tokens': 182, 'output_tokens': 0, 'total_tokens': 182})}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='Bee', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' def', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='ies physics', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=';', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' Barry', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' cho', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='oses outfit', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' for', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' graduation', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' day', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='.', id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3')}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', usage_metadata={'input_tokens': 0, 'output_tokens': 16, 'total_tokens': 16})}, 'run_id': 'cd8c1bd9-64d8-463c-a4d7-4bceed7911b3', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['8ddd1325-07c4-4213-8a2f-4462db8c6c70', '9f8654b4-b3f6-414e-b41d-dd201342a2fa']}\n"
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42', usage_metadata={'input_tokens': 182, 'output_tokens': 0, 'total_tokens': 182})}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='Bee', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' def', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='ies physics', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=';', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' Barry', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' cho', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='oses outfit', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' for', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' graduation', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' day', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='.', id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42')}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-f24ab147-0b82-4e63-810a-b12bd8d1fb42', usage_metadata={'input_tokens': 0, 'output_tokens': 16, 'total_tokens': 16})}, 'run_id': 'f24ab147-0b82-4e63-810a-b12bd8d1fb42', 'name': 'ChatAnthropic', 'tags': ['seq:step:2'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['385f3612-417c-4a70-aae0-cce3a5ba6fb6']}\n"
]
}
],
"source": [
"stream = special_summarization_tool.astream_events(\n",
"stream = special_summarization_tool_with_config.astream_events(\n",
" {\"long_text\": LONG_TEXT}, version=\"v2\"\n",
")\n",
"\n",
@@ -193,65 +258,14 @@
" print(event)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that you still have access to the final tool response as well. You can access it by looking for an `on_tool_end` event.\n",
"\n",
"To make events your tool emits easier to identify, you can also add identifiers to runnables using the `with_config()` method. `run_name` will apply to only to the runnable you attach it to, while `tags` will be inherited by runnables called within your initial runnable.\n",
"\n",
"Let's redeclare the tool with a tag, then run it with `astream_events()` with some filters. You should only see streamed events from the chat model and the final tool output:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630', usage_metadata={'input_tokens': 182, 'output_tokens': 0, 'total_tokens': 182})}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='Bee', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' def', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='ies physics', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=';', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' Barry', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' cho', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='oses outfit', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' for', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' graduation', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content=' day', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='.', id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630')}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_chat_model_stream', 'data': {'chunk': AIMessageChunk(content='', response_metadata={'stop_reason': 'end_turn', 'stop_sequence': None}, id='run-696f4fc8-6c6f-46a0-8c82-e2e3f7625630', usage_metadata={'input_tokens': 0, 'output_tokens': 16, 'total_tokens': 16})}, 'run_id': '696f4fc8-6c6f-46a0-8c82-e2e3f7625630', 'name': 'ChatAnthropic', 'tags': ['seq:step:2', 'bee_movie'], 'metadata': {'ls_provider': 'anthropic', 'ls_model_name': 'claude-3-5-sonnet-20240620', 'ls_model_type': 'chat', 'ls_temperature': 0.0, 'ls_max_tokens': 1024}, 'parent_ids': ['49d9d7d3-2b02-4964-a6c5-12f57a063146', '8922d0e3-4199-4ba5-9a7a-fc4f2fca3e72']}\n",
"{'event': 'on_tool_end', 'data': {'output': '.yad noitaudarg rof tiftuo sesoohc yrraB ;scisyhp seifed eeB'}, 'run_id': '49d9d7d3-2b02-4964-a6c5-12f57a063146', 'name': 'special_summarization_tool', 'tags': ['bee_movie'], 'metadata': {}, 'parent_ids': []}\n"
]
}
],
"source": [
"tagged_tool = special_summarization_tool.with_config({\"tags\": [\"bee_movie\"]})\n",
"\n",
"stream = tagged_tool.astream_events(\n",
" {\"long_text\": LONG_TEXT}, version=\"v2\", include_tags=[\"bee_movie\"]\n",
")\n",
"\n",
"async for event in stream:\n",
" event_type = event[\"event\"]\n",
" if event_type == \"on_chat_model_stream\" or event_type == \"on_tool_end\":\n",
" print(event)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"Now you've learned how to stream events from within a tool. Next, you can learn more about how to use tools:\n",
"You've now seen how to stream events from within a tool. Next, check out the following guides for more on using tools:\n",
"\n",
"- Bind [model-specific tools](/docs/how_to/tools_model_specific/)\n",
"- Pass [runtime values to tools](/docs/how_to/tool_runtime)\n",
"- Pass [tool results back to a model](/docs/how_to/tool_results_pass_to_model)\n",
"\n",

View File

@@ -419,13 +419,13 @@
"Invoking: `exponentiate` with `{'base': 405, 'exponent': 2}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3m164025\u001b[0m\u001b[32;1m\u001b[1;3mThe result of taking 3 to the fifth power is 243. \n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3m13286025\u001b[0m\u001b[32;1m\u001b[1;3mThe result of taking 3 to the fifth power is 243. \n",
"\n",
"The sum of twelve and three is 15. \n",
"\n",
"Multiplying 243 by 15 gives 3645. \n",
"\n",
"Finally, squaring 3645 gives 164025.\u001b[0m\n",
"Finally, squaring 3645 gives 13286025.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -434,7 +434,7 @@
"data": {
"text/plain": [
"{'input': 'Take 3 to the fifth power and multiply that by the sum of twelve and three, then square the whole result',\n",
" 'output': 'The result of taking 3 to the fifth power is 243. \\n\\nThe sum of twelve and three is 15. \\n\\nMultiplying 243 by 15 gives 3645. \\n\\nFinally, squaring 3645 gives 164025.'}"
" 'output': 'The result of taking 3 to the fifth power is 243. \\n\\nThe sum of twelve and three is 15. \\n\\nMultiplying 243 by 15 gives 3645. \\n\\nFinally, squaring 3645 gives 13286025.'}"
]
},
"execution_count": 18,

View File

@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 1,
"id": "10ad9224",
"metadata": {
"ExecuteTime": {
@@ -1809,7 +1809,6 @@
"cell_type": "markdown",
"id": "0c69d84d",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
@@ -1891,7 +1890,6 @@
"cell_type": "markdown",
"id": "5da41b77",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
@@ -2149,6 +2147,7 @@
},
{
"cell_type": "markdown",
"id": "2ac1a8c7",
"metadata": {},
"source": [
"## SingleStoreDB Semantic Cache\n",
@@ -2173,6 +2172,353 @@
")"
]
},
{
"cell_type": "markdown",
"id": "7019c991-0101-4f9c-b212-5729a5471293",
"metadata": {},
"source": [
"## Couchbase Caches\n",
"\n",
"Use [Couchbase](https://couchbase.com/) as a cache for prompts and responses."
]
},
{
"cell_type": "markdown",
"id": "d6aac680-ba32-4c19-8864-6471cf0e7d5a",
"metadata": {},
"source": [
"### Couchbase Cache\n",
"\n",
"The standard cache that looks for an exact match of the user prompt."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9b4764e4-c75f-4185-b326-524287a826be",
"metadata": {},
"outputs": [],
"source": [
"# Create couchbase connection object\n",
"from datetime import timedelta\n",
"\n",
"from couchbase.auth import PasswordAuthenticator\n",
"from couchbase.cluster import Cluster\n",
"from couchbase.options import ClusterOptions\n",
"from langchain_couchbase.cache import CouchbaseCache\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"COUCHBASE_CONNECTION_STRING = (\n",
" \"couchbase://localhost\" # or \"couchbases://localhost\" if using TLS\n",
")\n",
"DB_USERNAME = \"Administrator\"\n",
"DB_PASSWORD = \"Password\"\n",
"\n",
"auth = PasswordAuthenticator(DB_USERNAME, DB_PASSWORD)\n",
"options = ClusterOptions(auth)\n",
"cluster = Cluster(COUCHBASE_CONNECTION_STRING, options)\n",
"\n",
"# Wait until the cluster is ready for use.\n",
"cluster.wait_until_ready(timedelta(seconds=5))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4b5e73c5-92c1-4eab-84e2-77924ea9c123",
"metadata": {},
"outputs": [],
"source": [
"# Specify the bucket, scope and collection to store the cached documents\n",
"BUCKET_NAME = \"langchain-testing\"\n",
"SCOPE_NAME = \"_default\"\n",
"COLLECTION_NAME = \"_default\"\n",
"\n",
"set_llm_cache(\n",
" CouchbaseCache(\n",
" cluster=cluster,\n",
" bucket_name=BUCKET_NAME,\n",
" scope_name=SCOPE_NAME,\n",
" collection_name=COLLECTION_NAME,\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "db8d28cc-8d93-47b4-8326-57a29a06fb3c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 22.2 ms, sys: 14 ms, total: 36.2 ms\n",
"Wall time: 938 ms\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was two-tired!\""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in the cache, so it should take longer\n",
"llm.invoke(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b470dc81-2e7f-4743-9435-ce9071394eea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 53 ms, sys: 29 ms, total: 82 ms\n",
"Wall time: 84.2 ms\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was two-tired!\""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time, it is in the cache, so it should be much faster\n",
"llm.invoke(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "43626f33-d184-4260-b641-c9341cef5842",
"metadata": {},
"source": [
"### Couchbase Semantic Cache\n",
"Semantic caching allows users to retrieve cached prompts based on semantic similarity between the user input and previously cached inputs. Under the hood it uses Couchbase as both a cache and a vectorstore. This needs an appropriate Vector Search Index defined to work. Please look at the usage example on how to set up the index."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6b470c03-d7fe-4270-89e1-638251619a53",
"metadata": {},
"outputs": [],
"source": [
"# Create Couchbase connection object\n",
"from datetime import timedelta\n",
"\n",
"from couchbase.auth import PasswordAuthenticator\n",
"from couchbase.cluster import Cluster\n",
"from couchbase.options import ClusterOptions\n",
"from langchain_couchbase.cache import CouchbaseSemanticCache\n",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
"\n",
"COUCHBASE_CONNECTION_STRING = (\n",
" \"couchbase://localhost\" # or \"couchbases://localhost\" if using TLS\n",
")\n",
"DB_USERNAME = \"Administrator\"\n",
"DB_PASSWORD = \"Password\"\n",
"\n",
"auth = PasswordAuthenticator(DB_USERNAME, DB_PASSWORD)\n",
"options = ClusterOptions(auth)\n",
"cluster = Cluster(COUCHBASE_CONNECTION_STRING, options)\n",
"\n",
"# Wait until the cluster is ready for use.\n",
"cluster.wait_until_ready(timedelta(seconds=5))"
]
},
{
"cell_type": "markdown",
"id": "f831bc4c-f330-4bd7-9b80-76771d91827e",
"metadata": {},
"source": [
"Notes:\n",
"- The search index for the semantic cache needs to be defined before using the semantic cache. \n",
"- The optional parameter, `score_threshold` in the Semantic Cache that you can use to tune the results of the semantic search.\n",
"\n",
"### How to Import an Index to the Full Text Search service?\n",
" - [Couchbase Server](https://docs.couchbase.com/server/current/search/import-search-index.html)\n",
" - Click on Search -> Add Index -> Import\n",
" - Copy the following Index definition in the Import screen\n",
" - Click on Create Index to create the index.\n",
" - [Couchbase Capella](https://docs.couchbase.com/cloud/search/import-search-index.html)\n",
" - Copy the index definition to a new file `index.json`\n",
" - Import the file in Capella using the instructions in the documentation.\n",
" - Click on Create Index to create the index.\n",
"\n",
"#### Example index for the vector search. \n",
" ```\n",
" {\n",
" \"type\": \"fulltext-index\",\n",
" \"name\": \"langchain-testing._default.semantic-cache-index\",\n",
" \"sourceType\": \"gocbcore\",\n",
" \"sourceName\": \"langchain-testing\",\n",
" \"planParams\": {\n",
" \"maxPartitionsPerPIndex\": 1024,\n",
" \"indexPartitions\": 16\n",
" },\n",
" \"params\": {\n",
" \"doc_config\": {\n",
" \"docid_prefix_delim\": \"\",\n",
" \"docid_regexp\": \"\",\n",
" \"mode\": \"scope.collection.type_field\",\n",
" \"type_field\": \"type\"\n",
" },\n",
" \"mapping\": {\n",
" \"analysis\": {},\n",
" \"default_analyzer\": \"standard\",\n",
" \"default_datetime_parser\": \"dateTimeOptional\",\n",
" \"default_field\": \"_all\",\n",
" \"default_mapping\": {\n",
" \"dynamic\": true,\n",
" \"enabled\": false\n",
" },\n",
" \"default_type\": \"_default\",\n",
" \"docvalues_dynamic\": false,\n",
" \"index_dynamic\": true,\n",
" \"store_dynamic\": true,\n",
" \"type_field\": \"_type\",\n",
" \"types\": {\n",
" \"_default.semantic-cache\": {\n",
" \"dynamic\": false,\n",
" \"enabled\": true,\n",
" \"properties\": {\n",
" \"embedding\": {\n",
" \"dynamic\": false,\n",
" \"enabled\": true,\n",
" \"fields\": [\n",
" {\n",
" \"dims\": 1536,\n",
" \"index\": true,\n",
" \"name\": \"embedding\",\n",
" \"similarity\": \"dot_product\",\n",
" \"type\": \"vector\",\n",
" \"vector_index_optimized_for\": \"recall\"\n",
" }\n",
" ]\n",
" },\n",
" \"metadata\": {\n",
" \"dynamic\": true,\n",
" \"enabled\": true\n",
" },\n",
" \"text\": {\n",
" \"dynamic\": false,\n",
" \"enabled\": true,\n",
" \"fields\": [\n",
" {\n",
" \"index\": true,\n",
" \"name\": \"text\",\n",
" \"store\": true,\n",
" \"type\": \"text\"\n",
" }\n",
" ]\n",
" }\n",
" }\n",
" }\n",
" }\n",
" },\n",
" \"store\": {\n",
" \"indexType\": \"scorch\",\n",
" \"segmentVersion\": 16\n",
" }\n",
" },\n",
" \"sourceParams\": {}\n",
" }\n",
" ```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ae0766c8-ea34-4604-b0dc-cf2bbe8077f4",
"metadata": {},
"outputs": [],
"source": [
"BUCKET_NAME = \"langchain-testing\"\n",
"SCOPE_NAME = \"_default\"\n",
"COLLECTION_NAME = \"semantic-cache\"\n",
"INDEX_NAME = \"semantic-cache-index\"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"cache = CouchbaseSemanticCache(\n",
" cluster=cluster,\n",
" embedding=embeddings,\n",
" bucket_name=BUCKET_NAME,\n",
" scope_name=SCOPE_NAME,\n",
" collection_name=COLLECTION_NAME,\n",
" index_name=INDEX_NAME,\n",
" score_threshold=0.8,\n",
")\n",
"\n",
"set_llm_cache(cache)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a2e82743-10ea-4319-b43e-193475ae5449",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"The average lifespan of a dog is around 12 years, but this can vary depending on the breed, size, and overall health of the individual dog. Some smaller breeds may live longer, while larger breeds may have shorter lifespans. Proper care, diet, and exercise can also play a role in extending a dog's lifespan.\n",
"CPU times: user 826 ms, sys: 2.46 s, total: 3.28 s\n",
"Wall time: 2.87 s\n"
]
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in the cache, so it should take longer\n",
"print(llm.invoke(\"How long do dogs live?\"))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c36f4e29-d872-4334-a1f1-0e6d10c5d9f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"The average lifespan of a dog is around 12 years, but this can vary depending on the breed, size, and overall health of the individual dog. Some smaller breeds may live longer, while larger breeds may have shorter lifespans. Proper care, diet, and exercise can also play a role in extending a dog's lifespan.\n",
"CPU times: user 9.82 ms, sys: 2.61 ms, total: 12.4 ms\n",
"Wall time: 311 ms\n"
]
}
],
"source": [
"%%time\n",
"# The second time, it is in the cache, so it should be much faster\n",
"print(llm.invoke(\"What is the expected lifespan of a dog?\"))"
]
},
{
"cell_type": "markdown",
"id": "ae1f5e1c-085e-4998-9f2d-b5867d2c3d5b",
@@ -2228,7 +2574,9 @@
"| langchain_core.caches | [InMemoryCache](https://api.python.langchain.com/en/latest/caches/langchain_core.caches.InMemoryCache.html) |\n",
"| langchain_elasticsearch.cache | [ElasticsearchCache](https://api.python.langchain.com/en/latest/cache/langchain_elasticsearch.cache.ElasticsearchCache.html) |\n",
"| langchain_mongodb.cache | [MongoDBAtlasSemanticCache](https://api.python.langchain.com/en/latest/cache/langchain_mongodb.cache.MongoDBAtlasSemanticCache.html) |\n",
"| langchain_mongodb.cache | [MongoDBCache](https://api.python.langchain.com/en/latest/cache/langchain_mongodb.cache.MongoDBCache.html) |\n"
"| langchain_mongodb.cache | [MongoDBCache](https://api.python.langchain.com/en/latest/cache/langchain_mongodb.cache.MongoDBCache.html) |\n",
"| langchain_couchbase.cache | [CouchbaseCache](https://api.python.langchain.com/en/latest/cache/langchain_couchbase.cache.CouchbaseCache.html) |\n",
"| langchain_couchbase.cache | [CouchbaseSemanticCache](https://api.python.langchain.com/en/latest/cache/langchain_couchbase.cache.CouchbaseSemanticCache.html) |\n"
]
},
{
@@ -2256,7 +2604,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.13"
}
},
"nbformat": 4,

View File

@@ -27,7 +27,7 @@
"outputs": [],
"source": [
"# Install the package\n",
"%pip install --upgrade --quiet dashscope"
"%pip install --upgrade --quiet langchain-community dashscope"
]
},
{

View File

@@ -27,3 +27,65 @@ See a [usage example](/docs/integrations/document_loaders/couchbase).
```python
from langchain_community.document_loaders.couchbase import CouchbaseLoader
```
## LLM Caches
### CouchbaseCache
Use Couchbase as a cache for prompts and responses.
See a [usage example](/docs/integrations/llm_caching/#couchbase-cache).
To import this cache:
```python
from langchain_couchbase.cache import CouchbaseCache
```
To use this cache with your LLMs:
```python
from langchain_core.globals import set_llm_cache
cluster = couchbase_cluster_connection_object
set_llm_cache(
CouchbaseCache(
cluster=cluster,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
)
)
```
### CouchbaseSemanticCache
Semantic caching allows users to retrieve cached prompts based on the semantic similarity between the user input and previously cached inputs. Under the hood it uses Couchbase as both a cache and a vectorstore.
The CouchbaseSemanticCache needs a Search Index defined in order to work. See the [usage example](/docs/integrations/vectorstores/couchbase) for how to set up the index.
See a [usage example](/docs/integrations/llm_caching/#couchbase-semantic-cache).
To import this cache:
```python
from langchain_couchbase.cache import CouchbaseSemanticCache
```
To use this cache with your LLMs:
```python
from langchain_core.globals import set_llm_cache
# use any embedding provider...
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
cluster = couchbase_cluster_connection_object
set_llm_cache(
CouchbaseSemanticCache(
cluster=cluster,
embedding=embeddings,
bucket_name=BUCKET_NAME,
scope_name=SCOPE_NAME,
collection_name=COLLECTION_NAME,
index_name=INDEX_NAME,
)
)
```
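
For a quick sanity check of either cache, repeated identical prompts should be served from Couchbase on the second call. This is a minimal sketch, assuming an `llm` such as `ChatOpenAI` has already been constructed and one of the caches above has been set:
```python
import time

start = time.time()
llm.invoke("Tell me a joke")  # first call: computed by the model and stored in the cache
print(f"uncached: {time.time() - start:.2f}s")

start = time.time()
llm.invoke("Tell me a joke")  # second call: served from the Couchbase cache
print(f"cached: {time.time() - start:.2f}s")
```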

View File

@@ -61,7 +61,7 @@ When ready to deploy, you can self-host models with NVIDIA NIM—which is includ
```python
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank
# connect to an chat NIM running at localhost:8000, specifyig a specific model
# connect to a chat NIM running at localhost:8000, specifying a model
llm = ChatNVIDIA(base_url="http://localhost:8000/v1", model="meta/llama3-8b-instruct")
# connect to an embedding NIM running at localhost:8080

View File

@@ -202,7 +202,7 @@ Prem Templates are also available for Streaming too.
## Prem Embeddings
In this section we are going to dicuss how we can get access to different embedding model using `PremEmbeddings` with LangChain. Lets start by importing our modules and setting our API Key.
In this section we cover how we can get access to different embedding models using `PremEmbeddings` with LangChain. Let's start by importing our modules and setting our API Key.
```python
import os

View File

@@ -309,9 +309,9 @@
"documents = TextLoader(\"../../how_to/state_of_the_union.txt\").load()\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
"texts = text_splitter.split_documents(documents)\n",
"retriever = FAISS.from_documents(texts, CohereEmbeddings()).as_retriever(\n",
" search_kwargs={\"k\": 20}\n",
")\n",
"retriever = FAISS.from_documents(\n",
" texts, CohereEmbeddings(model=\"embed-english-v3.0\")\n",
").as_retriever(search_kwargs={\"k\": 20})\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = retriever.invoke(query)\n",
@@ -324,7 +324,8 @@
"metadata": {},
"source": [
"## Doing reranking with CohereRerank\n",
"Now let's wrap our base retriever with a `ContextualCompressionRetriever`. We'll add an `CohereRerank`, uses the Cohere rerank endpoint to rerank the returned results."
"Now let's wrap our base retriever with a `ContextualCompressionRetriever`. We'll add an `CohereRerank`, uses the Cohere rerank endpoint to rerank the returned results.\n",
"Do note that it is mandatory to specify the model name in CohereRerank!"
]
},
{
@@ -339,7 +340,7 @@
"from langchain_community.llms import Cohere\n",
"\n",
"llm = Cohere(temperature=0)\n",
"compressor = CohereRerank()\n",
"compressor = CohereRerank(model=\"rerank-english-v3.0\")\n",
"compression_retriever = ContextualCompressionRetriever(\n",
" base_compressor=compressor, base_retriever=retriever\n",
")\n",

View File

@@ -40,7 +40,9 @@
"metadata": {},
"outputs": [],
"source": [
"embeddings = CohereEmbeddings(model=\"embed-english-light-v3.0\")"
"embeddings = CohereEmbeddings(\n",
" model=\"embed-english-light-v3.0\"\n",
") # It is mandatory to pass a model parameter to initialize the CohereEmbeddings object"
]
},
{

View File

@@ -169,6 +169,23 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Specify additional properties for the Azure client such as the following https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations\n",
"vector_store: AzureSearch = AzureSearch(\n",
" azure_search_endpoint=vector_store_address,\n",
" azure_search_key=vector_store_password,\n",
" index_name=index_name,\n",
" embedding_function=embeddings.embed_query,\n",
" # Configure max retries for the Azure client\n",
" additional_search_client_options={\"retry_total\": 4},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},

View File

@@ -78,7 +78,7 @@
"# See docker command above to launch a postgres instance with pgvector enabled.\n",
"connection = \"postgresql+psycopg://langchain:langchain@localhost:6024/langchain\" # Uses psycopg3!\n",
"collection_name = \"my_docs\"\n",
"embeddings = CohereEmbeddings()\n",
"embeddings = CohereEmbeddings(model=\"embed-english-v3.0\")\n",
"\n",
"vectorstore = PGVector(\n",
" embeddings=embeddings,\n",

View File

@@ -107,7 +107,7 @@
"```\n",
"## Preview\n",
"\n",
"In this guide well build a QA app over as website. The specific website we will use is the [LLM Powered Autonomous\n",
"In this guide well build an app that answers questions about the content of a website. The specific website we will use is the [LLM Powered Autonomous\n",
"Agents](https://lilianweng.github.io/posts/2023-06-23-agent/) blog post\n",
"by Lilian Weng, which allows us to ask questions about the contents of\n",
"the post.\n",

View File

@@ -25,7 +25,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
class ChatAnyscale(ChatOpenAI):

View File

@@ -141,9 +141,8 @@ class CustomOpenAIChatContentFormatter(ContentFormatterBase):
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
return ChatGeneration(
message=BaseMessage(
message=AIMessage(
content=choice.strip(),
type="assistant",
),
generation_info=None,
)
@@ -158,7 +157,9 @@ class CustomOpenAIChatContentFormatter(ContentFormatterBase):
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
return ChatGeneration(
message=BaseMessage(
message=AIMessage(content=choice["message"]["content"].strip())
if choice["message"]["role"] == "assistant"
else BaseMessage(
content=choice["message"]["content"].strip(),
type=choice["message"]["role"],
),

View File

@@ -48,6 +48,7 @@ from langchain_core.messages import (
ToolMessage,
)
from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
@@ -96,7 +97,7 @@ def _parse_tool_calling(tool_call: dict) -> ToolCall:
name = tool_call["function"].get("name", "")
args = json.loads(tool_call["function"]["arguments"])
id = tool_call.get("id")
return ToolCall(name=name, args=args, id=id)
return create_tool_call(name=name, args=args, id=id)
def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:

View File

@@ -36,9 +36,11 @@ from langchain_core.messages import (
InvalidToolCall,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
@@ -63,7 +65,7 @@ def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationCh
message = generated_result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
ToolCallChunk(
create_tool_call_chunk(
name=tool_call["name"],
args=json.dumps(tool_call["args"]),
id=tool_call["id"],
@@ -189,7 +191,7 @@ def _extract_tool_calls_from_edenai_response(
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(
ToolCall(
create_tool_call(
name=raw_tool_call["name"],
args=json.loads(raw_tool_call["arguments"]),
id=raw_tool_call["id"],
@@ -197,7 +199,7 @@ def _extract_tool_calls_from_edenai_response(
)
except json.JSONDecodeError as exc:
invalid_tool_calls.append(
InvalidToolCall(
create_invalid_tool_call(
name=raw_tool_call.get("name"),
args=raw_tool_call.get("arguments"),
id=raw_tool_call.get("id"),

View File

@@ -144,7 +144,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
elif (
isinstance(temp_image_url, dict) and "url" in temp_image_url
):
image_url = temp_image_url
image_url = temp_image_url["url"]
else:
raise ValueError(
"Only string image_url or dict with string 'url' "

View File

@@ -60,7 +60,7 @@ class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
List of scores, one for each pair.
"""
scores = self.client.predict(text_pairs)
# Somes models e.g bert-multilingual-passage-reranking-msmarco
# Some models e.g bert-multilingual-passage-reranking-msmarco
# gives two score not_relevant and relevant as compare with the query.
if len(scores.shape) > 1: # we are going to get the relevant scores
scores = map(lambda x: x[1], scores)

View File

@@ -60,7 +60,7 @@ class AscendEmbeddings(Embeddings, BaseModel):
raise ValueError("model_path is required")
if not os.access(values["model_path"], os.F_OK):
raise FileNotFoundError(
f"Unabled to find valid model path in [{values['model_path']}]"
f"Unable to find valid model path in [{values['model_path']}]"
)
try:
import torch_npu

View File

@@ -555,10 +555,11 @@ class Neo4jGraph(GraphStore):
el["labelsOrTypes"] == [BASE_ENTITY_LABEL]
and el["properties"] == ["id"]
for el in self.structured_schema.get("metadata", {}).get(
"constraint"
"constraint", []
)
]
)
if not constraint_exists:
# Create constraint
self.query(

View File

@@ -1,5 +1,6 @@
import json
import logging
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
from langchain_core._api.deprecation import deprecated
@@ -11,7 +12,6 @@ from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
@@ -177,16 +177,17 @@ class HuggingFaceEndpoint(LLM):
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
try:
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
login(token=huggingfacehub_api_token)
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv(
"HUGGINGFACEHUB_API_TOKEN"
)
if huggingfacehub_api_token is not None:
try:
login(token=huggingfacehub_api_token)
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
from huggingface_hub import AsyncInferenceClient, InferenceClient

View File

@@ -72,7 +72,7 @@ class SQLStore(BaseStore[str, bytes]):
from langchain_rag.storage import SQLStore
# Instantiate the SQLStore with the root path
sql_store = SQLStore(namespace="test", db_url="sqllite://:memory:")
sql_store = SQLStore(namespace="test", db_url="sqlite://:memory:")
# Set values for keys
sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

View File

@@ -80,7 +80,7 @@ class SemanticScholarAPIWrapper(BaseModel):
f"Published year: {getattr(item, 'year', None)}\n"
f"Title: {getattr(item, 'title', None)}\n"
f"Authors: {authors}\n"
f"Astract: {getattr(item, 'abstract', None)}\n"
f"Abstract: {getattr(item, 'abstract', None)}\n"
)
if documents:

View File

@@ -86,6 +86,7 @@ def _get_search_client(
user_agent: Optional[str] = "langchain",
cors_options: Optional[CorsOptions] = None,
async_: bool = False,
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError
@@ -109,6 +110,7 @@ def _get_search_client(
VectorSearchProfile,
)
additional_search_client_options = additional_search_client_options or {}
default_fields = default_fields or []
if key is None:
credential = DefaultAzureCredential()
@@ -225,6 +227,7 @@ def _get_search_client(
index_name=index_name,
credential=credential,
user_agent=user_agent,
**additional_search_client_options,
)
else:
return AsyncSearchClient(
@@ -232,6 +235,7 @@ def _get_search_client(
index_name=index_name,
credential=credential,
user_agent=user_agent,
**additional_search_client_options,
)
@@ -256,6 +260,7 @@ class AzureSearch(VectorStore):
cors_options: Optional[CorsOptions] = None,
*,
vector_search_dimensions: Optional[int] = None,
additional_search_client_options: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
try:
@@ -320,6 +325,22 @@ class AzureSearch(VectorStore):
default_fields=default_fields,
user_agent=user_agent,
cors_options=cors_options,
additional_search_client_options=additional_search_client_options,
)
self.async_client = _get_search_client(
azure_search_endpoint,
azure_search_key,
index_name,
semantic_configuration_name=semantic_configuration_name,
fields=fields,
vector_search=vector_search,
semantic_configurations=semantic_configurations,
scoring_profiles=scoring_profiles,
default_scoring_profile=default_scoring_profile,
default_fields=default_fields,
user_agent=user_agent,
cors_options=cors_options,
async_=True,
)
self.search_type = search_type
self.semantic_configuration_name = semantic_configuration_name
@@ -338,23 +359,6 @@ class AzureSearch(VectorStore):
self._user_agent = user_agent
self._cors_options = cors_options
def _async_client(self) -> AsyncSearchClient:
return _get_search_client(
self._azure_search_endpoint,
self._azure_search_key,
self._index_name,
semantic_configuration_name=self._semantic_configuration_name,
fields=self._fields,
vector_search=self._vector_search,
semantic_configurations=self._semantic_configurations,
scoring_profiles=self._scoring_profiles,
default_scoring_profile=self._default_scoring_profile,
default_fields=self._default_fields,
user_agent=self._user_agent,
cors_options=self._cors_options,
async_=True,
)
@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Support embedding object directly
@@ -513,7 +517,7 @@ class AzureSearch(VectorStore):
ids.append(key)
# Upload data in batches
if len(data) == MAX_UPLOAD_BATCH_SIZE:
async with self._async_client() as async_client:
async with self.async_client as async_client:
response = await async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if not all(r.succeeded for r in response):
@@ -526,7 +530,7 @@ class AzureSearch(VectorStore):
return ids
# Upload data to index
async with self._async_client() as async_client:
async with self.async_client as async_client:
response = await async_client.upload_documents(documents=data)
# Check if all documents were successfully uploaded
if all(r.succeeded for r in response):
@@ -561,7 +565,7 @@ class AzureSearch(VectorStore):
False otherwise.
"""
if ids:
async with self._async_client() as async_client:
async with self.async_client as async_client:
res = await async_client.delete_documents([{"id": i} for i in ids])
return len(res) > 0
else:
@@ -739,11 +743,11 @@ class AzureSearch(VectorStore):
to the query and score for each
"""
embedding = await self._aembed_query(query)
docs, scores, _ = await self._asimple_search(
results = await self._asimple_search(
embedding, "", k, filters=filters, **kwargs
)
return list(zip(docs, scores))
return _results_to_documents(results)
def max_marginal_relevance_search_with_score(
self,
@@ -807,14 +811,12 @@ class AzureSearch(VectorStore):
to the query and score for each
"""
embedding = await self._aembed_query(query)
docs, scores, vectors = await self._asimple_search(
results = await self._asimple_search(
embedding, "", fetch_k, filters=filters, **kwargs
)
return await self._areorder_results_with_maximal_marginal_relevance(
docs,
scores,
vectors,
return await _areorder_results_with_maximal_marginal_relevance(
results,
query_embedding=np.array(embedding),
lambda_mult=lambda_mult,
k=k,
@@ -890,11 +892,11 @@ class AzureSearch(VectorStore):
"""
embedding = await self._aembed_query(query)
docs, scores, _ = await self._asimple_search(
results = await self._asimple_search(
embedding, query, k, filters=filters, **kwargs
)
return list(zip(docs, scores))
return _results_to_documents(results)
def hybrid_search_with_relevance_scores(
self,
@@ -992,14 +994,12 @@ class AzureSearch(VectorStore):
"""
embedding = await self._aembed_query(query)
docs, scores, vectors = await self._asimple_search(
results = await self._asimple_search(
embedding, query, fetch_k, filters=filters, **kwargs
)
return await self._areorder_results_with_maximal_marginal_relevance(
docs,
scores,
vectors,
return await _areorder_results_with_maximal_marginal_relevance(
results,
query_embedding=np.array(embedding),
lambda_mult=lambda_mult,
k=k,
@@ -1049,7 +1049,7 @@ class AzureSearch(VectorStore):
*,
filters: Optional[str] = None,
**kwargs: Any,
) -> Tuple[List[Document], List[float], List[List[float]]]:
) -> SearchItemPaged[dict]:
"""Perform vector or hybrid search in the Azure search index.
Args:
@@ -1063,8 +1063,8 @@ class AzureSearch(VectorStore):
"""
from azure.search.documents.models import VectorizedQuery
async with self._async_client() as async_client:
results = await async_client.search(
async with self.async_client as async_client:
return await async_client.search(
search_text=text_query,
vector_queries=[
VectorizedQuery(
@@ -1077,18 +1077,6 @@ class AzureSearch(VectorStore):
top=k,
**kwargs,
)
docs = [
(
_result_to_document(result),
float(result["@search.score"]),
result[FIELDS_CONTENT_VECTOR],
)
async for result in results
]
if not docs:
raise ValueError(f"No {docs=}")
documents, scores, vectors = map(list, zip(*docs))
return documents, scores, vectors
def semantic_hybrid_search(
self, query: str, k: int = 4, **kwargs: Any
@@ -1300,7 +1288,7 @@ class AzureSearch(VectorStore):
from azure.search.documents.models import VectorizedQuery
vector = await self._aembed_query(query)
async with self._async_client() as async_client:
async with self.async_client as async_client:
results = await async_client.search(
search_text=query,
vector_queries=[
@@ -1475,30 +1463,6 @@ class AzureSearch(VectorStore):
azure_search.add_embeddings(text_embeddings, metadatas, **kwargs)
return azure_search
async def _areorder_results_with_maximal_marginal_relevance(
self,
documents: List[Document],
scores: List[float],
vectors: List[List[float]],
query_embedding: np.ndarray,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[Tuple[Document, float]]:
# Get the new order of results.
new_ordering = maximal_marginal_relevance(
query_embedding, vectors, k=k, lambda_mult=lambda_mult
)
# Reorder the values and return.
ret: List[Tuple[Document, float]] = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
ret.append((documents[x], scores[x])) # type: ignore
return ret
def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore
"""Return AzureSearchVectorStoreRetriever initialized from this VectorStore.
@@ -1666,6 +1630,39 @@ def _results_to_documents(
return docs
async def _areorder_results_with_maximal_marginal_relevance(
results: SearchItemPaged[Dict],
query_embedding: np.ndarray,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[Tuple[Document, float]]:
# Convert results to Document objects
docs = [
(
_result_to_document(result),
float(result["@search.score"]),
result[FIELDS_CONTENT_VECTOR],
)
for result in results
]
documents, scores, vectors = map(list, zip(*docs))
# Get the new order of results.
new_ordering = maximal_marginal_relevance(
query_embedding, vectors, k=k, lambda_mult=lambda_mult
)
# Reorder the values and return.
ret: List[Tuple[Document, float]] = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
ret.append((documents[x], scores[x])) # type: ignore
return ret
def _reorder_results_with_maximal_marginal_relevance(
results: SearchItemPaged[Dict],
query_embedding: np.ndarray,

View File

@@ -9,7 +9,7 @@ from langchain_community.tools.zenguard.tool import Detector, ZenGuardTool
@pytest.fixture()
def zenguard_tool() -> ZenGuardTool:
if os.getenv("ZENGUARD_API_KEY") is None:
raise ValueError("ZENGUARD_API_KEY is not set in environment varibale")
raise ValueError("ZENGUARD_API_KEY is not set in environment variable")
return ZenGuardTool()

View File

@@ -12,7 +12,7 @@ PAGE_1 = """
Hello.
<a href="relative">Relative</a>
<a href="/relative-base">Relative base.</a>
<a href="http://cnn.com">Aboslute</a>
<a href="http://cnn.com">Absolute</a>
<a href="//same.foo">Test</a>
</body>
</html>

View File

@@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
"""Test that tools defined in this repo accept a run manager argument."""
# This wouldn't be necessary if the BaseTool had a strict API.
if cls._run is not BaseTool._arun:
if cls._run is not BaseTool._run:
run_func = cls._run
params = inspect.signature(run_func).parameters
assert "run_manager" in params

View File

@@ -1,5 +1,5 @@
import json
from typing import List, Optional
from typing import Any, Dict, List, Optional
from unittest.mock import patch
import pytest
@@ -121,12 +121,15 @@ def mock_default_index(*args, **kwargs): # type: ignore[no-untyped-def]
)
def create_vector_store() -> AzureSearch:
def create_vector_store(
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> AzureSearch:
return AzureSearch(
azure_search_endpoint=DEFAULT_ENDPOINT,
azure_search_key=DEFAULT_KEY,
index_name=DEFAULT_INDEX_NAME,
embedding_function=DEFAULT_EMBEDDING_MODEL,
additional_search_client_options=additional_search_client_options,
)
@@ -168,3 +171,20 @@ def test_init_new_index() -> None:
assert json.dumps(created_index.as_dict()) == json.dumps(
mock_default_index().as_dict()
)
@pytest.mark.requires("azure.search.documents")
def test_additional_search_options() -> None:
from azure.search.documents.indexes import SearchIndexClient
def mock_create_index() -> None:
pytest.fail("Should not create index in this test")
with patch.multiple(
SearchIndexClient, get_index=mock_default_index, create_index=mock_create_index
):
vector_store = create_vector_store(
additional_search_client_options={"api_version": "test"}
)
assert vector_store.client is not None
assert vector_store.client._api_version == "test"

View File

@@ -15,7 +15,7 @@ PathLike = Union[str, PurePath]
class BaseMedia(Serializable):
"""Use to represent media content.
Media objets can be used to represent raw data, such as text or binary data.
Media objects can be used to represent raw data, such as text or binary data.
LangChain Media objects allow associating metadata and an optional identifier
with the content.

View File

@@ -12,6 +12,7 @@ from typing import (
Optional,
)
from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
@@ -130,11 +131,14 @@ def nodes_to_documents(nodes: Iterable[Node]) -> Iterator[Document]:
)
@beta(message="Added in version 0.2.14 of langchain_core. API subject to change.")
class GraphVectorStore(VectorStore):
"""A hybrid vector-and-graph graph store.
Document chunks support vector-similarity search as well as edges linking
chunks based on structural and semantic properties.
.. versionadded:: 0.2.14
"""
@abstractmethod

View File

@@ -15,11 +15,18 @@ from langchain_core.messages.tool import (
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.messages.tool import (
invalid_tool_call as create_invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import (
parse_partial_json,
)
from langchain_core.utils.json import parse_partial_json
class UsageMetadata(TypedDict):
@@ -106,24 +113,55 @@ class AIMessage(BaseMessage):
@root_validator(pre=True)
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
values.get("tool_calls")
or values.get("invalid_tool_calls")
or values.get("tool_call_chunks")
check_additional_kwargs = not any(
values.get(k)
for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
)
if raw_tool_calls and not tool_calls:
if check_additional_kwargs and (
raw_tool_calls := values.get("additional_kwargs", {}).get("tool_calls")
):
try:
if issubclass(cls, AIMessageChunk): # type: ignore
values["tool_call_chunks"] = default_tool_chunk_parser(
raw_tool_calls
)
else:
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
parsed_tool_calls, parsed_invalid_tool_calls = default_tool_parser(
raw_tool_calls
)
values["tool_calls"] = parsed_tool_calls
values["invalid_tool_calls"] = parsed_invalid_tool_calls
except Exception:
pass
# Ensure "type" is properly set on all tool call-like dicts.
if tool_calls := values.get("tool_calls"):
updated: List = []
for tc in tool_calls:
updated.append(
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
)
values["tool_calls"] = updated
if invalid_tool_calls := values.get("invalid_tool_calls"):
updated = []
for tc in invalid_tool_calls:
updated.append(
create_invalid_tool_call(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["invalid_tool_calls"] = updated
if tool_call_chunks := values.get("tool_call_chunks"):
updated = []
for tc in tool_call_chunks:
updated.append(
create_tool_call_chunk(
**{k: v for k, v in tc.items() if k != "type"}
)
)
values["tool_call_chunks"] = updated
return values
def pretty_repr(self, html: bool = False) -> str:
@@ -216,7 +254,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
if not values["tool_call_chunks"]:
if values["tool_calls"]:
values["tool_call_chunks"] = [
ToolCallChunk(
create_tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
@@ -228,7 +266,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
tool_call_chunks = values.get("tool_call_chunks", [])
tool_call_chunks.extend(
[
ToolCallChunk(
create_tool_call_chunk(
name=tc["name"], args=tc["args"], id=tc["id"], index=None
)
for tc in values["invalid_tool_calls"]
@@ -244,7 +282,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
if isinstance(args_, dict):
tool_calls.append(
ToolCall(
create_tool_call(
name=chunk["name"] or "",
args=args_,
id=chunk["id"],
@@ -254,7 +292,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
raise ValueError("Malformed args.")
except Exception:
invalid_tool_calls.append(
InvalidToolCall(
create_invalid_tool_call(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
@@ -297,7 +335,7 @@ def add_ai_message_chunks(
left.tool_call_chunks, *(o.tool_call_chunks for o in others)
):
tool_call_chunks = [
ToolCallChunk(
create_tool_call_chunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),

View File

@@ -1,7 +1,7 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.utils._merge import merge_dicts, merge_obj
@@ -21,8 +21,11 @@ class ToolMessage(BaseMessage):
ToolMessage(content='42', tool_call_id='call_Jja7J89XsjrOLA5r!MEOW!SL')
Example: A ToolMessage where only part of the tool output is sent to the model
and the full output is passed in to raw_output.
and the full output is passed in to artifact.
.. versionadded:: 0.2.17
.. code-block:: python
@@ -36,7 +39,7 @@ class ToolMessage(BaseMessage):
ToolMessage(
content=tool_output["stdout"],
raw_output=tool_output,
artifact=tool_output,
tool_call_id='call_Jja7J89XsjrOLA5r!MEOW!SL',
)
@@ -54,12 +57,14 @@ class ToolMessage(BaseMessage):
type: Literal["tool"] = "tool"
"""The type of the message (used for serialization). Defaults to "tool"."""
raw_output: Any = None
"""The raw output of the tool.
artifact: Any = None
"""Artifact of the Tool execution which is not meant to be sent to the model.
**Not part of the payload sent to the model.** Should only be specified if it is
different from the message content, i.e. if only a subset of the full tool output
is being passed as message content.
Should only be specified if it is different from the message content, e.g. if only
a subset of the full tool output is being passed as message content but the full
output is needed in other parts of the code.
.. versionadded:: 0.2.17
"""
@classmethod
@@ -106,7 +111,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
return self.__class__(
tool_call_id=self.tool_call_id,
content=merge_content(self.content, other.content),
raw_output=merge_obj(self.raw_output, other.raw_output),
artifact=merge_obj(self.artifact, other.artifact),
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
@@ -146,6 +151,11 @@ class ToolCall(TypedDict):
An identifier is needed to associate a tool call request with a tool
call result in events when multiple concurrent tool calls are made.
"""
type: NotRequired[Literal["tool_call"]]
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
return ToolCall(name=name, args=args, id=id, type="tool_call")
class ToolCallChunk(TypedDict):
@@ -176,6 +186,19 @@ class ToolCallChunk(TypedDict):
"""An identifier associated with the tool call."""
index: Optional[int]
"""The index of the tool call in a sequence."""
type: NotRequired[Literal["tool_call_chunk"]]
def tool_call_chunk(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
index: Optional[int] = None,
) -> ToolCallChunk:
return ToolCallChunk(
name=name, args=args, id=id, index=index, type="tool_call_chunk"
)
class InvalidToolCall(TypedDict):
@@ -193,6 +216,19 @@ class InvalidToolCall(TypedDict):
"""An identifier associated with the tool call."""
error: Optional[str]
"""An error message associated with the tool call."""
type: NotRequired[Literal["invalid_tool_call"]]
def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
error: Optional[str] = None,
) -> InvalidToolCall:
return InvalidToolCall(
name=name, args=args, id=id, error=error, type="invalid_tool_call"
)
def default_tool_parser(
@@ -201,25 +237,25 @@ def default_tool_parser(
"""Best-effort parsing of tools."""
tool_calls = []
invalid_tool_calls = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
for raw_tool_call in raw_tool_calls:
if "function" not in raw_tool_call:
continue
else:
function_name = tool_call["function"]["name"]
function_name = raw_tool_call["function"]["name"]
try:
function_args = json.loads(tool_call["function"]["arguments"])
parsed = ToolCall(
function_args = json.loads(raw_tool_call["function"]["arguments"])
parsed = tool_call(
name=function_name or "",
args=function_args or {},
id=tool_call.get("id"),
id=raw_tool_call.get("id"),
)
tool_calls.append(parsed)
except json.JSONDecodeError:
invalid_tool_calls.append(
InvalidToolCall(
invalid_tool_call(
name=function_name,
args=tool_call["function"]["arguments"],
id=tool_call.get("id"),
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),
error=None,
)
)
@@ -236,7 +272,7 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
else:
function_args = tool_call["function"]["arguments"]
function_name = tool_call["function"]["name"]
parsed = ToolCallChunk(
parsed = tool_call_chunk(
name=function_name,
args=function_args,
id=tool_call.get("id"),
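
The rename from `raw_output` to `artifact` above matches the updated docstring example: only `content` is sent to the model, while the full tool output rides along on the message. A minimal sketch (the `tool_output` dict and the `tool_call_id` are placeholders):

```python
from langchain_core.messages import ToolMessage

tool_output = {"stdout": "42", "stderr": "", "returncode": 0}  # hypothetical tool result

msg = ToolMessage(
    content=tool_output["stdout"],  # the only part sent to the model
    artifact=tool_output,           # full output, kept for downstream code
    tool_call_id="call_123",        # placeholder id
)
```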

View File

@@ -16,6 +16,7 @@ from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
@@ -40,6 +41,7 @@ if TYPE_CHECKING:
from langchain_text_splitters import TextSplitter
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables.base import Runnable
AnyMessage = Union[
@@ -221,8 +223,8 @@ def _create_message_from_message_type(
elif message_type == "function":
message = FunctionMessage(content=content, **kwargs)
elif message_type == "tool":
raw_output = kwargs.get("additional_kwargs", {}).pop("raw_output", None)
message = ToolMessage(content=content, raw_output=raw_output, **kwargs)
artifact = kwargs.get("additional_kwargs", {}).pop("artifact", None)
message = ToolMessage(content=content, artifact=artifact, **kwargs)
elif message_type == "remove":
message = RemoveMessage(**kwargs)
else:
@@ -284,7 +286,7 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
def convert_to_messages(
messages: Sequence[MessageLikeRepresentation],
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> List[BaseMessage]:
"""Convert a sequence of messages to a list of messages.
@@ -294,6 +296,11 @@ def convert_to_messages(
Returns:
List of messages (BaseMessages).
"""
# Import here to avoid circular imports
from langchain_core.prompt_values import PromptValue
if isinstance(messages, PromptValue):
return messages.to_messages()
return [_convert_to_message(m) for m in messages]
@@ -329,7 +336,7 @@ def _runnable_support(func: Callable) -> Callable:
@_runnable_support
def filter_messages(
messages: Sequence[MessageLikeRepresentation],
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
*,
include_names: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
@@ -417,7 +424,7 @@ def filter_messages(
@_runnable_support
def merge_message_runs(
messages: Sequence[MessageLikeRepresentation],
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> List[BaseMessage]:
"""Merge consecutive Messages of the same type.
@@ -451,12 +458,12 @@ def merge_message_runs(
HumanMessage("wait your favorite food", id="bar",),
AIMessage(
"my favorite colo",
tool_calls=[ToolCall(name="blah_tool", args={"x": 2}, id="123")],
tool_calls=[ToolCall(name="blah_tool", args={"x": 2}, id="123", type="tool_call")],
id="baz",
),
AIMessage(
[{"type": "text", "text": "my favorite dish is lasagna"}],
tool_calls=[ToolCall(name="blah_tool", args={"x": -10}, id="456")],
tool_calls=[ToolCall(name="blah_tool", args={"x": -10}, id="456", type="tool_call")],
id="blur",
),
]
@@ -474,8 +481,8 @@ def merge_message_runs(
{"type": "text", "text": "my favorite dish is lasagna"}
],
tool_calls=[
ToolCall({"name": "blah_tool", "args": {"x": 2}, "id": "123"),
ToolCall({"name": "blah_tool", "args": {"x": -10}, "id": "456")
ToolCall({"name": "blah_tool", "args": {"x": 2}, "id": "123", "type": "tool_call"}),
ToolCall({"name": "blah_tool", "args": {"x": -10}, "id": "456", "type": "tool_call"})
]
id="baz"
),
@@ -506,7 +513,7 @@ def merge_message_runs(
@_runnable_support
def trim_messages(
messages: Sequence[MessageLikeRepresentation],
messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
*,
max_tokens: int,
token_counter: Union[
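
Given the widened signatures above, a `PromptValue` can now be passed directly to `convert_to_messages` (and therefore to `filter_messages`, `merge_message_runs`, and `trim_messages`). A minimal sketch:

```python
from langchain_core.messages import HumanMessage, SystemMessage, convert_to_messages
from langchain_core.prompt_values import ChatPromptValue

prompt_value = ChatPromptValue(
    messages=[SystemMessage("You are helpful."), HumanMessage("hi")]
)

# The PromptValue is unwrapped via .to_messages() before conversion.
print(convert_to_messages(prompt_value))
```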

View File

@@ -5,6 +5,12 @@ from typing import Any, Dict, List, Optional, Type
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import (
invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
@@ -59,6 +65,7 @@ def parse_tool_call(
}
if return_id:
parsed["id"] = raw_tool_call.get("id")
parsed = create_tool_call(**parsed) # type: ignore
return parsed
@@ -75,7 +82,7 @@ def make_invalid_tool_call(
Returns:
An InvalidToolCall instance with the error message.
"""
return InvalidToolCall(
return invalid_tool_call(
name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),

File diff suppressed because it is too large

View File

@@ -48,6 +48,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
If no condition evaluates to True, the default branch is run on the input.
Parameters:
branches: A list of (condition, Runnable) pairs.
default: A Runnable to run if no condition is met.
Examples:
.. code-block:: python
@@ -82,7 +86,18 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
RunnableLike, # To accommodate the default branch
],
) -> None:
"""A Runnable that runs one of two branches based on a condition."""
"""A Runnable that runs one of two branches based on a condition.
Args:
*branches: A list of (condition, Runnable) pairs.
Defaults a Runnable to run if no condition is met.
Raises:
ValueError: If the number of branches is less than 2.
TypeError: If the default branch is not Runnable, Callable or Mapping.
TypeError: If a branch is not a tuple or list.
ValueError: If a branch is not of length 2.
"""
if len(branches) < 2:
raise ValueError("RunnableBranch requires at least two branches")
@@ -93,7 +108,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
(Runnable, Callable, Mapping), # type: ignore[arg-type]
):
raise TypeError(
"RunnableBranch default must be runnable, callable or mapping."
"RunnableBranch default must be Runnable, callable or mapping."
)
default_ = cast(
@@ -176,7 +191,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
"""First evaluates the condition, then delegate to true or false branch.
Args:
input: The input to the Runnable.
config: The configuration for the Runnable. Defaults to None.
**kwargs: Additional keyword arguments to pass to the Runnable.
Returns:
The output of the branch that was run.
Raises:
"""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
@@ -277,7 +304,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""First evaluates the condition,
then delegate to true or false branch."""
then delegate to true or false branch.
Args:
input: The input to the Runnable.
config: The configuration for the Runnable. Defaults to None.
**kwargs: Additional keyword arguments to pass to the Runnable.
Yields:
The output of the branch that was run.
Raises:
BaseException: If an error occurs during the execution of the Runnable.
"""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
@@ -352,7 +391,19 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
"""First evaluates the condition,
then delegate to true or false branch."""
then delegate to true or false branch.
Args:
input: The input to the Runnable.
config: The configuration for the Runnable. Defaults to None.
**kwargs: Additional keyword arguments to pass to the Runnable.
Yields:
The output of the branch that was run.
Raises:
BaseException: If an error occurs during the execution of the Runnable.
"""
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
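
The expanded docstrings above describe the constructor contract; a minimal usage sketch of that contract (positional `(condition, Runnable)` pairs followed by a default branch):

```python
from langchain_core.runnables import RunnableBranch, RunnableLambda

branch = RunnableBranch(
    # (condition, Runnable) pairs are evaluated in order...
    (lambda x: isinstance(x, str), RunnableLambda(lambda x: x.upper())),
    (lambda x: isinstance(x, int), RunnableLambda(lambda x: x + 1)),
    # ...and the final positional argument is the default branch.
    RunnableLambda(lambda x: str(x)),
)

print(branch.invoke("hello"))  # 'HELLO'
print(branch.invoke(41))       # 42
```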

View File

@@ -111,7 +111,7 @@ var_child_runnable_config = ContextVar(
def _set_config_context(config: RunnableConfig) -> None:
"""Set the child runnable config + tracing context
"""Set the child Runnable config + tracing context
Args:
config (RunnableConfig): The config to set.
@@ -216,7 +216,6 @@ def patch_config(
Args:
config (Optional[RunnableConfig]): The config to patch.
copy_locals (bool, optional): Whether to copy locals. Defaults to False.
callbacks (Optional[BaseCallbackManager], optional): The callbacks to set.
Defaults to None.
recursion_limit (Optional[int], optional): The recursion limit to set.
@@ -362,9 +361,9 @@ def call_func_with_variable_args(
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output]]):
The function to call.
input (Input): The input to the function.
run_manager (CallbackManagerForChainRun): The run manager to
pass to the function.
config (RunnableConfig): The config to pass to the function.
run_manager (CallbackManagerForChainRun): The run manager to
pass to the function. Defaults to None.
**kwargs (Any): The keyword arguments to pass to the function.
Returns:
@@ -395,7 +394,7 @@ def acall_func_with_variable_args(
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
**kwargs: Any,
) -> Awaitable[Output]:
"""Call function that may optionally accept a run_manager and/or config.
"""Async call function that may optionally accept a run_manager and/or config.
Args:
func (Union[Callable[[Input], Awaitable[Output]], Callable[[Input,
@@ -403,9 +402,9 @@ def acall_func_with_variable_args(
AsyncCallbackManagerForChainRun, RunnableConfig], Awaitable[Output]]]):
The function to call.
input (Input): The input to the function.
run_manager (AsyncCallbackManagerForChainRun): The run manager
to pass to the function.
config (RunnableConfig): The config to pass to the function.
run_manager (AsyncCallbackManagerForChainRun): The run manager
to pass to the function. Defaults to None.
**kwargs (Any): The keyword arguments to pass to the function.
Returns:
@@ -493,6 +492,18 @@ class ContextThreadPoolExecutor(ThreadPoolExecutor):
timeout: float | None = None,
chunksize: int = 1,
) -> Iterator[T]:
"""Map a function to multiple iterables.
Args:
fn (Callable[..., T]): The function to map.
*iterables (Iterable[Any]): The iterables to map over.
timeout (float | None, optional): The timeout for the map.
Defaults to None.
chunksize (int, optional): The chunksize for the map. Defaults to 1.
Returns:
Iterator[T]: The iterator for the mapped function.
"""
contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
def _wrapped_fn(*args: Any) -> T:
@@ -534,13 +545,16 @@ async def run_in_executor(
"""Run a function in an executor.
Args:
executor (Executor): The executor.
executor_or_config: The executor or config to run in.
func (Callable[P, Output]): The function.
*args (Any): The positional arguments to the function.
**kwargs (Any): The keyword arguments to the function.
Returns:
Output: The output of the function.
Raises:
RuntimeError: If the function raises a StopIteration.
"""
def wrapper() -> T:

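The `run_in_executor` helper documented above runs a blocking function off the event loop while preserving the runnable context. A minimal sketch, assuming (per the signature shown above) that passing `None` falls back to the default, context-aware executor:

```python
import asyncio

from langchain_core.runnables.config import run_in_executor


def blocking_square(x: int) -> int:
    # Stand-in for CPU-bound or blocking I/O work.
    return x * x


async def main() -> None:
    # None means "no explicit executor or config": the default executor is used.
    result = await run_in_executor(None, blocking_square, 7)
    print(result)  # 49


asyncio.run(main())
```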

@@ -44,7 +44,15 @@ from langchain_core.runnables.utils import (
class DynamicRunnable(RunnableSerializable[Input, Output]):
"""Serializable Runnable that can be dynamically configured."""
"""Serializable Runnable that can be dynamically configured.
A DynamicRunnable should be initiated using the `configurable_fields` or
`configurable_alternatives` method of a Runnable.
Parameters:
default: The default Runnable to use.
config: The configuration to use.
"""
default: RunnableSerializable[Input, Output]
@@ -99,6 +107,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def prepare(
self, config: Optional[RunnableConfig] = None
) -> Tuple[Runnable[Input, Output], RunnableConfig]:
"""Prepare the Runnable for invocation.
Args:
config: The configuration to use. Defaults to None.
Returns:
Tuple[Runnable[Input, Output], RunnableConfig]: The prepared Runnable and
configuration.
"""
runnable: Runnable[Input, Output] = self
while isinstance(runnable, DynamicRunnable):
runnable, config = runnable._prepare(merge_configs(runnable.config, config))
@@ -284,6 +301,9 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
A RunnableConfigurableFields should be initiated using the
`configurable_fields` method of a Runnable.
Parameters:
fields: The configurable fields to use.
Here is an example of using a RunnableConfigurableFields with LLMs:
.. code-block:: python
@@ -348,6 +368,11 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
"""Get the configuration specs for the RunnableConfigurableFields.
Returns:
List[ConfigurableFieldSpec]: The configuration specs.
"""
return get_unique_config_specs(
[
(
@@ -374,6 +399,8 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
def configurable_fields(
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:
"""Get a new RunnableConfigurableFields with the specified
configurable fields."""
return self.default.configurable_fields(**{**self.fields, **kwargs})
def _prepare(
@@ -493,11 +520,13 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
""" # noqa: E501
which: ConfigurableField
"""The ConfigurableField to use to choose between alternatives."""
alternatives: Dict[
str,
Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
]
"""The alternatives to choose from."""
default_key: str = "default"
"""The enum value to use for the default option. Defaults to "default"."""
@@ -619,7 +648,7 @@ def prefix_config_spec(
prefix: The prefix to add.
Returns:
ConfigurableFieldSpec: The prefixed ConfigurableFieldSpec.
"""
return (
ConfigurableFieldSpec(
@@ -641,6 +670,13 @@ def make_options_spec(
) -> ConfigurableFieldSpec:
"""Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
ConfigurableFieldMultiOption.
Args:
spec: The ConfigurableFieldSingleOption or ConfigurableFieldMultiOption.
description: The description to use if the spec does not have one.
Returns:
The ConfigurableFieldSpec.
"""
with _enums_for_spec_lock:
if enum := _enums_for_spec.get(spec):

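A minimal sketch of the `configurable_fields` pattern documented above, assuming `langchain-openai` is installed and an OpenAI key is configured (any Runnable with declared fields works the same way):

```python
from langchain_core.runnables import ConfigurableField
from langchain_openai import ChatOpenAI  # assumption: langchain-openai is installed

model = ChatOpenAI(temperature=0).configurable_fields(
    temperature=ConfigurableField(
        id="llm_temperature",
        name="LLM Temperature",
        description="Sampling temperature used by the model",
    )
)

# The default invocation uses temperature=0; the config below overrides it at run time.
model.with_config(configurable={"llm_temperature": 0.9}).invoke("Pick a random number")
```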

@@ -91,7 +91,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
"""
runnable: Runnable[Input, Output]
"""The runnable to run first."""
"""The Runnable to run first."""
fallbacks: Sequence[Runnable[Input, Output]]
"""A sequence of fallbacks to try."""
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
@@ -102,7 +102,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
exception_key: Optional[str] = None
"""If string is specified then handled exceptions will be passed to fallbacks as
part of the input under the specified key. If None, exceptions
will not be passed to fallbacks. If used, the base runnable and its fallbacks
will not be passed to fallbacks. If used, the base Runnable and its fallbacks
must accept a dictionary as input."""
class Config:
@@ -554,7 +554,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
await run_manager.on_chain_end(output)
def __getattr__(self, name: str) -> Any:
"""Get an attribute from the wrapped runnable and its fallbacks.
"""Get an attribute from the wrapped Runnable and its fallbacks.
Returns:
If the attribute is anything other than a method that outputs a Runnable,

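A sketch of the `exception_key` behaviour described above: when the key is set, the handled exception is injected into the fallback's dict input under that key (function and key names below are illustrative):

```python
from langchain_core.runnables import RunnableLambda


def primary(inputs: dict) -> str:
    raise ValueError("primary failed")


def backup(inputs: dict) -> str:
    # The exception raised by the primary Runnable arrives under the configured key.
    return f"recovered from: {inputs['exception']}"


chain = RunnableLambda(primary).with_fallbacks(
    [RunnableLambda(backup)],
    exceptions_to_handle=(ValueError,),
    exception_key="exception",
)

print(chain.invoke({"question": "hi"}))  # recovered from: primary failed
```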

@@ -57,7 +57,14 @@ def is_uuid(value: str) -> bool:
class Edge(NamedTuple):
"""Edge in a graph."""
"""Edge in a graph.
Parameters:
source: The source node id.
target: The target node id.
data: Optional data associated with the edge. Defaults to None.
conditional: Whether the edge is conditional. Defaults to False.
"""
source: str
target: str
@@ -67,6 +74,15 @@ class Edge(NamedTuple):
def copy(
self, *, source: Optional[str] = None, target: Optional[str] = None
) -> Edge:
"""Return a copy of the edge with optional new source and target nodes.
Args:
source: The new source node id. Defaults to None.
target: The new target node id. Defaults to None.
Returns:
A copy of the edge with the new source and target nodes.
"""
return Edge(
source=source or self.source,
target=target or self.target,
@@ -76,7 +92,14 @@ class Edge(NamedTuple):
class Node(NamedTuple):
"""Node in a graph."""
"""Node in a graph.
Parameters:
id: The unique identifier of the node.
name: The name of the node.
data: The data of the node.
metadata: Optional metadata for the node. Defaults to None.
"""
id: str
name: str
@@ -84,6 +107,15 @@ class Node(NamedTuple):
metadata: Optional[Dict[str, Any]]
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
"""Return a copy of the node with optional new id and name.
Args:
id: The new node id. Defaults to None.
name: The new node name. Defaults to None.
Returns:
A copy of the node with the new id and name.
"""
return Node(
id=id or self.id,
name=name or self.name,
@@ -93,7 +125,13 @@ class Node(NamedTuple):
class Branch(NamedTuple):
"""Branch in a graph."""
"""Branch in a graph.
Parameters:
condition: A callable that returns a string representation of the condition.
ends: Optional dictionary of end node ids for the branches. Defaults
to None.
"""
condition: Callable[..., str]
ends: Optional[dict[str, str]]
@@ -117,12 +155,18 @@ class CurveStyle(Enum):
@dataclass
class NodeColors:
"""Schema for Hexadecimal color codes for different node types"""
class NodeStyles:
"""Schema for Hexadecimal color codes for different node types.
start: str = "#ffdfba"
end: str = "#baffc9"
other: str = "#fad7de"
Parameters:
default: The default color code. Defaults to "fill:#f2f0ff,line-height:1.2".
first: The color code for the first node. Defaults to "fill-opacity:0".
last: The color code for the last node. Defaults to "fill:#bfb6fc".
"""
default: str = "fill:#f2f0ff,line-height:1.2"
first: str = "fill-opacity:0"
last: str = "fill:#bfb6fc"
class MermaidDrawMethod(Enum):
@@ -161,7 +205,7 @@ def node_data_json(
Args:
node: The node to convert.
with_schemas: Whether to include the schema of the data if
it is a Pydantic model.
it is a Pydantic model. Defaults to False.
Returns:
A dictionary with the type of the data and the data itself.
@@ -209,13 +253,26 @@ def node_data_json(
@dataclass
class Graph:
"""Graph of nodes and edges."""
"""Graph of nodes and edges.
Parameters:
nodes: Dictionary of nodes in the graph. Defaults to an empty dictionary.
edges: List of edges in the graph. Defaults to an empty list.
"""
nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
"""Convert the graph to a JSON-serializable format."""
"""Convert the graph to a JSON-serializable format.
Args:
with_schemas: Whether to include the schemas of the nodes if they are
Pydantic models. Defaults to False.
Returns:
A dictionary with the nodes and edges of the graph.
"""
stable_node_ids = {
node.id: i if is_uuid(node.id) else node.id
for i, node in enumerate(self.nodes.values())
@@ -247,6 +304,8 @@ class Graph:
return bool(self.nodes)
def next_id(self) -> str:
"""Return a new unique node
identifier that can be used to add a node to the graph."""
return uuid4().hex
def add_node(
@@ -256,7 +315,19 @@ class Graph:
*,
metadata: Optional[Dict[str, Any]] = None,
) -> Node:
"""Add a node to the graph and return it."""
"""Add a node to the graph and return it.
Args:
data: The data of the node.
id: The id of the node. Defaults to None.
metadata: Optional metadata for the node. Defaults to None.
Returns:
The node that was added to the graph.
Raises:
ValueError: If a node with the same id already exists.
"""
if id is not None and id in self.nodes:
raise ValueError(f"Node with id {id} already exists")
id = id or self.next_id()
@@ -265,7 +336,11 @@ class Graph:
return node
def remove_node(self, node: Node) -> None:
"""Remove a node from the graph and all edges connected to it."""
"""Remove a node from the graph and all edges connected to it.
Args:
node: The node to remove.
"""
self.nodes.pop(node.id)
self.edges = [
edge
@@ -280,7 +355,20 @@ class Graph:
data: Optional[Stringifiable] = None,
conditional: bool = False,
) -> Edge:
"""Add an edge to the graph and return it."""
"""Add an edge to the graph and return it.
Args:
source: The source node of the edge.
target: The target node of the edge.
data: Optional data associated with the edge. Defaults to None.
conditional: Whether the edge is conditional. Defaults to False.
Returns:
The edge that was added to the graph.
Raises:
ValueError: If the source or target node is not in the graph.
"""
if source.id not in self.nodes:
raise ValueError(f"Source node {source.id} not in graph")
if target.id not in self.nodes:
@@ -295,7 +383,15 @@ class Graph:
self, graph: Graph, *, prefix: str = ""
) -> Tuple[Optional[Node], Optional[Node]]:
"""Add all nodes and edges from another graph.
Note this doesn't check for duplicates, nor does it connect the graphs."""
Note this doesn't check for duplicates, nor does it connect the graphs.
Args:
graph: The graph to add.
prefix: The prefix to add to the node ids. Defaults to "".
Returns:
A tuple of the first and last nodes of the subgraph.
"""
if all(is_uuid(node.id) for node in graph.nodes.values()):
prefix = ""
@@ -350,7 +446,7 @@ class Graph:
def first_node(self) -> Optional[Node]:
"""Find the single node that is not a target of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph this node would be the origin."""
When drawing the graph, this node would be the origin."""
targets = {edge.target for edge in self.edges}
found: List[Node] = []
for node in self.nodes.values():
@@ -361,7 +457,7 @@ class Graph:
def last_node(self) -> Optional[Node]:
"""Find the single node that is not a source of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph this node would be the destination.
When drawing the graph, this node would be the destination.
"""
sources = {edge.source for edge in self.edges}
found: List[Node] = []
@@ -372,7 +468,7 @@ class Graph:
def trim_first_node(self) -> None:
"""Remove the first node if it exists and has a single outgoing edge,
ie. if removing it would not leave the graph without a "first" node."""
i.e., if removing it would not leave the graph without a "first" node."""
first_node = self.first_node()
if first_node:
if (
@@ -384,7 +480,7 @@ class Graph:
def trim_last_node(self) -> None:
"""Remove the last node if it exists and has a single incoming edge,
ie. if removing it would not leave the graph without a "last" node."""
i.e., if removing it would not leave the graph without a "last" node."""
last_node = self.last_node()
if last_node:
if (
@@ -395,6 +491,7 @@ class Graph:
self.remove_node(last_node)
def draw_ascii(self) -> str:
"""Draw the graph as an ASCII art string."""
from langchain_core.runnables.graph_ascii import draw_ascii
return draw_ascii(
@@ -403,6 +500,7 @@ class Graph:
)
def print_ascii(self) -> None:
"""Print the graph as an ASCII art string."""
print(self.draw_ascii()) # noqa: T201
@overload
@@ -427,6 +525,17 @@ class Graph:
fontname: Optional[str] = None,
labels: Optional[LabelsDict] = None,
) -> Union[bytes, None]:
"""Draw the graph as a PNG image.
Args:
output_file_path: The path to save the image to. If None, the image
is not saved. Defaults to None.
fontname: The name of the font to use. Defaults to None.
labels: Optional labels for nodes and edges in the graph. Defaults to None.
Returns:
The PNG image as bytes if output_file_path is None, None otherwise.
"""
from langchain_core.runnables.graph_png import PngDrawer
default_node_labels = {node.id: node.name for node in self.nodes.values()}
@@ -447,11 +556,21 @@ class Graph:
*,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(
start="#ffdfba", end="#baffc9", other="#fad7de"
),
node_colors: NodeStyles = NodeStyles(),
wrap_label_n_words: int = 9,
) -> str:
"""Draw the graph as a Mermaid syntax string.
Args:
with_styles: Whether to include styles in the syntax. Defaults to True.
curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
node_colors: The colors of the nodes. Defaults to NodeStyles().
wrap_label_n_words: The number of words to wrap the node labels at.
Defaults to 9.
Returns:
The Mermaid syntax string.
"""
from langchain_core.runnables.graph_mermaid import draw_mermaid
graph = self.reid()
@@ -465,7 +584,7 @@ class Graph:
last_node=last_node.id if last_node else None,
with_styles=with_styles,
curve_style=curve_style,
node_colors=node_colors,
node_styles=node_colors,
wrap_label_n_words=wrap_label_n_words,
)
@@ -473,15 +592,30 @@ class Graph:
self,
*,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(
start="#ffdfba", end="#baffc9", other="#fad7de"
),
node_colors: NodeStyles = NodeStyles(),
wrap_label_n_words: int = 9,
output_file_path: Optional[str] = None,
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
background_color: str = "white",
padding: int = 10,
) -> bytes:
"""Draw the graph as a PNG image using Mermaid.
Args:
curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
node_colors: The colors of the nodes. Defaults to NodeStyles().
wrap_label_n_words: The number of words to wrap the node labels at.
Defaults to 9.
output_file_path: The path to save the image to. If None, the image
is not saved. Defaults to None.
draw_method: The method to use to draw the graph.
Defaults to MermaidDrawMethod.API.
background_color: The color of the background. Defaults to "white".
padding: The padding around the graph. Defaults to 10.
Returns:
The PNG image as bytes.
"""
from langchain_core.runnables.graph_mermaid import draw_mermaid_png
mermaid_syntax = self.draw_mermaid(

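A small sketch of the Graph API described above, obtained from an ordinary LCEL chain (the stub model only keeps the example self-contained):

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}")
stub_model = RunnableLambda(lambda _: "why did the chicken ...")  # stand-in for a chat model
chain = prompt | stub_model | StrOutputParser()

graph = chain.get_graph()
graph.print_ascii()          # ASCII rendering; needs the optional `grandalf` package
print(graph.draw_mermaid())  # Mermaid syntax string, styled with the NodeStyles defaults
```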

@@ -17,6 +17,7 @@ class VertexViewer:
"""
HEIGHT = 3 # top and bottom box edges + text
"""Height of the box."""
def __init__(self, name: str) -> None:
self._h = self.HEIGHT # top and bottom box edges + text


@@ -8,7 +8,7 @@ from langchain_core.runnables.graph import (
Edge,
MermaidDrawMethod,
Node,
NodeColors,
NodeStyles,
)
@@ -20,21 +20,28 @@ def draw_mermaid(
last_node: Optional[str] = None,
with_styles: bool = True,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(),
node_styles: NodeStyles = NodeStyles(),
wrap_label_n_words: int = 9,
) -> str:
"""Draws a Mermaid graph using the provided graph data
"""Draws a Mermaid graph using the provided graph data.
Args:
nodes (dict[str, str]): List of node ids
edges (List[Edge]): List of edges, object with source,
target and data.
nodes (dict[str, str]): List of node ids.
edges (List[Edge]): List of edges, object with a source,
target and data.
first_node (str, optional): Id of the first node. Defaults to None.
last_node (str, optional): Id of the last node. Defaults to None.
with_styles (bool, optional): Whether to include styles in the graph.
Defaults to True.
curve_style (CurveStyle, optional): Curve style for the edges.
node_colors (NodeColors, optional): Node colors for different types.
Defaults to CurveStyle.LINEAR.
node_styles (NodeStyles, optional): Node colors for different types.
Defaults to NodeStyles().
wrap_label_n_words (int, optional): Words to wrap the edge labels.
Defaults to 9.
Returns:
str: Mermaid graph syntax
str: Mermaid graph syntax.
"""
# Initialize Mermaid graph configuration
mermaid_graph = (
@@ -49,23 +56,27 @@ def draw_mermaid(
if with_styles:
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
format_dict = {default_class_label: "{0}({1})"}
if first_node is not None:
format_dict[first_node] = "{0}[{0}]:::startclass"
format_dict[first_node] = "{0}([{0}]):::first"
if last_node is not None:
format_dict[last_node] = "{0}[{0}]:::endclass"
format_dict[last_node] = "{0}([{0}]):::last"
# Add nodes to the graph
for key, node in nodes.items():
label = node.name.split(":")[-1]
if node.metadata:
label = f"<strong>{label}</strong>\n" + "\n".join(
f"{key} = {value}" for key, value in node.metadata.items()
label = (
f"{label}<hr/><small><em>"
+ "\n".join(
f"{key} = {value}" for key, value in node.metadata.items()
)
+ "</em></small>"
)
node_label = format_dict.get(key, format_dict[default_class_label]).format(
_escape_node_label(key), label
)
mermaid_graph += f"\t{node_label};\n"
mermaid_graph += f"\t{node_label}\n"
subgraph = ""
# Add edges to the graph
@@ -89,16 +100,14 @@ def draw_mermaid(
words = str(edge_data).split() # Split the string into words
# Group words into chunks of wrap_label_n_words size
if len(words) > wrap_label_n_words:
edge_data = "<br>".join(
[
" ".join(words[i : i + wrap_label_n_words])
for i in range(0, len(words), wrap_label_n_words)
]
edge_data = "&nbsp<br>&nbsp".join(
" ".join(words[i : i + wrap_label_n_words])
for i in range(0, len(words), wrap_label_n_words)
)
if edge.conditional:
edge_label = f" -. {edge_data} .-> "
edge_label = f" -. &nbsp{edge_data}&nbsp .-> "
else:
edge_label = f" -- {edge_data} --> "
edge_label = f" -- &nbsp{edge_data}&nbsp --> "
else:
if edge.conditional:
edge_label = " -.-> "
@@ -113,7 +122,7 @@ def draw_mermaid(
# Add custom styles for nodes
if with_styles:
mermaid_graph += _generate_mermaid_graph_styles(node_colors)
mermaid_graph += _generate_mermaid_graph_styles(node_styles)
return mermaid_graph
@@ -122,11 +131,11 @@ def _escape_node_label(node_label: str) -> str:
return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)
def _generate_mermaid_graph_styles(node_colors: NodeColors) -> str:
def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
"""Generates Mermaid graph styles for different node types."""
styles = ""
for class_name, color in asdict(node_colors).items():
styles += f"\tclassDef {class_name}class fill:{color};\n"
for class_name, style in asdict(node_colors).items():
styles += f"\tclassDef {class_name} {style}\n"
return styles
@@ -137,7 +146,24 @@ def draw_mermaid_png(
background_color: Optional[str] = "white",
padding: int = 10,
) -> bytes:
"""Draws a Mermaid graph as PNG using provided syntax."""
"""Draws a Mermaid graph as PNG using provided syntax.
Args:
mermaid_syntax (str): Mermaid graph syntax.
output_file_path (str, optional): Path to save the PNG image.
Defaults to None.
draw_method (MermaidDrawMethod, optional): Method to draw the graph.
Defaults to MermaidDrawMethod.API.
background_color (str, optional): Background color of the image.
Defaults to "white".
padding (int, optional): Padding around the image. Defaults to 10.
Returns:
bytes: PNG image bytes.
Raises:
ValueError: If an invalid draw method is provided.
"""
if draw_method == MermaidDrawMethod.PYPPETEER:
import asyncio

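A sketch of rendering the PNG described above; `MermaidDrawMethod.API` sends the Mermaid syntax to the mermaid.ink service, while `PYPPETEER` renders locally (the file name is illustrative):

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.graph import MermaidDrawMethod

chain = RunnableLambda(lambda x: x + 1)
png_bytes = chain.get_graph().draw_mermaid_png(
    draw_method=MermaidDrawMethod.API,
    output_file_path="chain_graph.png",  # written to disk and also returned as bytes
)
```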

@@ -6,7 +6,7 @@ from langchain_core.runnables.graph import Graph, LabelsDict
class PngDrawer:
"""Helper class to draw a state graph into a PNG file.
It requires graphviz and pygraphviz to be installed.
It requires `graphviz` and `pygraphviz` to be installed.
:param fontname: The font to use for the labels
:param labels: A dictionary of label overrides. The dictionary
should have the following format:
@@ -33,7 +33,7 @@ class PngDrawer:
"""Initializes the PNG drawer.
Args:
fontname: The font to use for the labels
fontname: The font to use for the labels. Defaults to "arial".
labels: A dictionary of label overrides. The dictionary
should have the following format:
{
@@ -48,6 +48,7 @@ class PngDrawer:
}
}
The keys are the original labels, and the values are the new labels.
Defaults to None.
"""
self.fontname = fontname or "arial"
self.labels = labels or LabelsDict(nodes={}, edges={})
@@ -56,7 +57,7 @@ class PngDrawer:
"""Returns the label to use for a node.
Args:
label: The original label
label: The original label.
Returns:
The new label.
@@ -68,7 +69,7 @@ class PngDrawer:
"""Returns the label to use for an edge.
Args:
label: The original label
label: The original label.
Returns:
The new label.
@@ -80,8 +81,8 @@ class PngDrawer:
"""Adds a node to the graph.
Args:
viz: The graphviz object
node: The node to add
viz: The graphviz object.
node: The node to add.
Returns:
None
@@ -106,9 +107,9 @@ class PngDrawer:
"""Adds an edge to the graph.
Args:
viz: The graphviz object
source: The source node
target: The target node
viz: The graphviz object.
source: The source node.
target: The target node.
label: The label for the edge. Defaults to None.
conditional: Whether the edge is conditional. Defaults to False.
@@ -127,7 +128,7 @@ class PngDrawer:
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
"""Draw the given state graph into a PNG file.
Requires graphviz and pygraphviz to be installed.
Requires `graphviz` and `pygraphviz` to be installed.
:param graph: The graph to draw
:param output_path: The path to save the PNG. If None, PNG bytes are returned.
"""
@@ -156,14 +157,32 @@ class PngDrawer:
viz.close()
def add_nodes(self, viz: Any, graph: Graph) -> None:
"""Add nodes to the graph.
Args:
viz: The graphviz object.
graph: The graph to draw.
"""
for node in graph.nodes:
self.add_node(viz, node)
def add_edges(self, viz: Any, graph: Graph) -> None:
"""Add edges to the graph.
Args:
viz: The graphviz object.
graph: The graph to draw.
"""
for start, end, data, cond in graph.edges:
self.add_edge(viz, start, end, str(data), cond)
def update_styles(self, viz: Any, graph: Graph) -> None:
"""Update the styles of the entrypoint and END nodes.
Args:
viz: The graphviz object.
graph: The graph to draw.
"""
if first := graph.first_node():
viz.get_node(first.id).attr.update(fillcolor="lightblue")
if last := graph.last_node():

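A sketch of the PngDrawer usage documented above; as noted, it requires the optional `graphviz` and `pygraphviz` dependencies, and the file name is illustrative:

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.graph_png import PngDrawer

chain = RunnableLambda(lambda x: x + 1)
drawer = PngDrawer(fontname="arial")
drawer.draw(chain.get_graph(), output_path="chain.png")  # returns PNG bytes when output_path is None
```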

@@ -45,13 +45,13 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history for it; it is responsible for reading and updating the chat message
history.
The formats supports for the inputs and outputs of the wrapped Runnable
The formats supported for the inputs and outputs of the wrapped Runnable
are described below.
RunnableWithMessageHistory must always be called with a config that contains
the appropriate parameters for the chat message history factory.
By default the Runnable is expected to take a single configuration parameter
By default, the Runnable is expected to take a single configuration parameter
called `session_id` which is a string. This parameter is used to create a new
or look up an existing chat message history that matches the given session_id.
@@ -70,6 +70,19 @@ class RunnableWithMessageHistory(RunnableBindingBase):
For production use cases, you will want to use a persistent implementation
of chat message history, such as ``RedisChatMessageHistory``.
Parameters:
get_session_history: Function that returns a new BaseChatMessageHistory.
This function should either take a single positional argument
`session_id` of type string and return a corresponding
chat message history instance.
input_messages_key: Must be specified if the base runnable accepts a dict
as input. The key in the input dict that contains the messages.
output_messages_key: Must be specified if the base Runnable returns a dict
as output. The key in the output dict that contains the messages.
history_messages_key: Must be specified if the base runnable accepts a dict
as input and expects a separate key for historical messages.
history_factory_config: Configure fields that should be passed to the
chat history factory. See ``ConfigurableFieldSpec`` for more details.
Example: Chat message history with an in-memory implementation for testing.
@@ -287,9 +300,9 @@ class RunnableWithMessageHistory(RunnableBindingBase):
...
input_messages_key: Must be specified if the base runnable accepts a dict
as input.
as input. Default is None.
output_messages_key: Must be specified if the base runnable returns a dict
as output.
as output. Default is None.
history_messages_key: Must be specified if the base runnable accepts a dict
as input and expects a separate key for historical messages.
history_factory_config: Configure fields that should be passed to the
@@ -347,6 +360,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
"""Get the configuration specs for the RunnableWithMessageHistory."""
return get_unique_config_specs(
super().config_specs + list(self.history_factory_config)
)

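A self-contained sketch of the parameters documented above, using the in-memory history implementation and a stub model so it runs without external services:

```python
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

store: dict = {}


def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    return store.setdefault(session_id, InMemoryChatMessageHistory())


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are helpful."),
        MessagesPlaceholder("history"),
        ("human", "{question}"),
    ]
)
stub_model = RunnableLambda(lambda _: AIMessage(content="hi there"))  # stand-in for a chat model

chain = RunnableWithMessageHistory(
    prompt | stub_model,
    get_session_history,
    input_messages_key="question",
    history_messages_key="history",
)

chain.invoke({"question": "Hello!"}, config={"configurable": {"session_id": "abc"}})
print(store["abc"].messages)  # [HumanMessage("Hello!"), AIMessage("hi there")]
```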

@@ -53,19 +53,33 @@ if TYPE_CHECKING:
def identity(x: Other) -> Other:
"""Identity function"""
"""Identity function.
Args:
x (Other): input.
Returns:
Other: output.
"""
return x
async def aidentity(x: Other) -> Other:
"""Async identity function"""
"""Async identity function.
Args:
x (Other): input.
Returns:
Other: output.
"""
return x
class RunnablePassthrough(RunnableSerializable[Other, Other]):
"""Runnable to passthrough inputs unchanged or with additional keys.
This runnable behaves almost like the identity function, except that it
This Runnable behaves almost like the identity function, except that it
can be configured to add additional keys to the output, if the input is a
dict.
@@ -73,6 +87,13 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
chains. The chains rely on simple lambdas to make the examples easy to execute
and experiment with.
Parameters:
func (Callable[[Other], None], optional): Function to be called with the input.
afunc (Callable[[Other], Awaitable[None]], optional): Async function to
be called with the input.
input_type (Optional[Type[Other]], optional): Type of the input.
**kwargs (Any): Additional keyword arguments.
Examples:
.. code-block:: python
@@ -199,10 +220,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
"""Merge the Dict input with the output produced by the mapping argument.
Args:
mapping: A mapping from keys to runnables or callables.
**kwargs: Runnable, Callable or a Mapping from keys to Runnables
or Callables.
Returns:
A runnable that merges the Dict input with the output produced by the
A Runnable that merges the Dict input with the output produced by the
mapping argument.
"""
return RunnableAssign(RunnableParallel(kwargs))
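A minimal sketch of the merge behaviour described above:

```python
from langchain_core.runnables import RunnablePassthrough

# assign() keeps the original keys and adds the ones produced by the mapping.
chain = RunnablePassthrough.assign(total=lambda d: d["a"] + d["b"])
print(chain.invoke({"a": 1, "b": 2}))  # {'a': 1, 'b': 2, 'total': 3}
```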
@@ -336,6 +358,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
these with the original data, introducing new key-value pairs based
on the mapper's logic.
Parameters:
mapper (RunnableParallel[Dict[str, Any]]): A `RunnableParallel` instance
that will be used to transform the input dictionary.
Examples:
.. code-block:: python
@@ -627,11 +653,15 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""Runnable that picks keys from Dict[str, Any] inputs.
RunnablePick class represents a runnable that selectively picks keys from a
RunnablePick class represents a Runnable that selectively picks keys from a
dictionary input. It allows you to specify one or more keys to extract
from the input dictionary. It returns a new dictionary containing only
the selected keys.
Parameters:
keys (Union[str, List[str]]): A single key or a list of keys to pick from
the input dictionary.
Example :
.. code-block:: python


@@ -112,7 +112,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
"""Whether to add jitter to the exponential backoff."""
max_attempt_number: int = 3
"""The maximum number of attempts to retry the runnable."""
"""The maximum number of attempts to retry the Runnable."""
@classmethod
def get_lc_namespace(cls) -> List[str]:

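`RunnableRetry` is normally created via `.with_retry()`; a small sketch of the `max_attempt_number` behaviour, which `stop_after_attempt` maps onto:

```python
from langchain_core.runnables import RunnableLambda

attempts = {"count": 0}


def flaky(x: int) -> int:
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ValueError("transient error")
    return x * 2


retrying = RunnableLambda(flaky).with_retry(
    retry_if_exception_type=(ValueError,),
    stop_after_attempt=3,  # becomes RunnableRetry.max_attempt_number
)

print(retrying.invoke(5))  # 10, after two retried failures
```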

@@ -38,7 +38,7 @@ class RouterInput(TypedDict):
Attributes:
key: The key to route on.
input: The input to pass to the selected runnable.
input: The input to pass to the selected Runnable.
"""
key: str
@@ -50,6 +50,9 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
Runnable that routes to a set of Runnables based on Input['key'].
Returns the output of the selected Runnable.
Parameters:
runnables: A mapping of keys to Runnables.
For example,
.. code-block:: python

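The class's own code example is not shown in this diff. As an illustration only, a minimal routing sketch:

```python
from langchain_core.runnables import RouterRunnable, RunnableLambda

router = RouterRunnable(
    {
        "square": RunnableLambda(lambda x: x * x),
        "negate": RunnableLambda(lambda x: -x),
    }
)

# RouterInput is a dict with "key" (which Runnable to use) and "input" (its payload).
print(router.invoke({"key": "square", "input": 4}))  # 16
```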

@@ -1,4 +1,4 @@
"""Module contains typedefs that are used with runnables."""
"""Module contains typedefs that are used with Runnables."""
from __future__ import annotations
@@ -11,7 +11,7 @@ class EventData(TypedDict, total=False):
"""Data associated with a streaming event."""
input: Any
"""The input passed to the runnable that generated the event.
"""The input passed to the Runnable that generated the event.
Inputs will sometimes be available at the *START* of the Runnable, and
sometimes at the *END* of the Runnable.
@@ -85,40 +85,43 @@ class BaseStreamEvent(TypedDict):
event: str
"""Event names are of the format: on_[runnable_type]_(start|stream|end).
Runnable types are one of:
* llm - used by non chat models
* chat_model - used by chat models
* prompt -- e.g., ChatPromptTemplate
* tool -- from tools defined via @tool decorator or inheriting from Tool/BaseTool
* chain - most Runnables are of this type
Runnable types are one of:
- **llm** - used by non chat models
- **chat_model** - used by chat models
- **prompt** -- e.g., ChatPromptTemplate
- **tool** -- from tools defined via @tool decorator or inheriting
from Tool/BaseTool
- **chain** - most Runnables are of this type
Further, the events are categorized as one of:
* start - when the runnable starts
* stream - when the runnable is streaming
* end - when the runnable ends
- **start** - when the Runnable starts
- **stream** - when the Runnable is streaming
- **end** - when the Runnable ends
start, stream and end are associated with slightly different `data` payload.
Please see the documentation for `EventData` for more details.
"""
run_id: str
"""An randomly generated ID to keep track of the execution of the given runnable.
"""An randomly generated ID to keep track of the execution of the given Runnable.
Each child runnable that gets invoked as part of the execution of a parent runnable
Each child Runnable that gets invoked as part of the execution of a parent Runnable
is assigned its own unique ID.
"""
tags: NotRequired[List[str]]
"""Tags associated with the runnable that generated this event.
"""Tags associated with the Runnable that generated this event.
Tags are always inherited from parent runnables.
Tags are always inherited from parent Runnables.
Tags can either be bound to a runnable using `.with_config({"tags": ["hello"]})`
Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
"""
metadata: NotRequired[Dict[str, Any]]
"""Metadata associated with the runnable that generated this event.
"""Metadata associated with the Runnable that generated this event.
Metadata can either be bound to a runnable using
Metadata can either be bound to a Runnable using
`.with_config({"metadata": { "foo": "bar" }})`
@@ -150,21 +153,20 @@ class StandardStreamEvent(BaseStreamEvent):
The contents of the event data depend on the event type.
"""
name: str
"""The name of the runnable that generated the event."""
"""The name of the Runnable that generated the event."""
class CustomStreamEvent(BaseStreamEvent):
"""A custom stream event created by the user.
"""Custom stream event created by the user.
.. versionadded:: 0.2.14
"""
# Overwrite the event field to be more specific.
event: Literal["on_custom_event"] # type: ignore[misc]
"""The event type."""
name: str
"""A user defined name for the event."""
"""User defined name for the event."""
data: Any
"""The data associated with the event. Free form and can be anything."""


@@ -43,6 +43,7 @@ Output = TypeVar("Output", covariant=True)
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
"""Run a coroutine with a semaphore.
Args:
semaphore: The semaphore to use.
coro: The coroutine to run.
@@ -59,7 +60,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
Args:
n: The number of coroutines to run concurrently.
coros: The coroutines to run.
*coros: The coroutines to run.
Returns:
The results of the coroutines.
@@ -73,7 +74,14 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
"""Check if a callable accepts a run_manager argument."""
"""Check if a callable accepts a run_manager argument.
Args:
callable: The callable to check.
Returns:
bool: True if the callable accepts a run_manager argument, False otherwise.
"""
try:
return signature(callable).parameters.get("run_manager") is not None
except ValueError:
@@ -81,7 +89,14 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
def accepts_config(callable: Callable[..., Any]) -> bool:
"""Check if a callable accepts a config argument."""
"""Check if a callable accepts a config argument.
Args:
callable: The callable to check.
Returns:
bool: True if the callable accepts a config argument, False otherwise.
"""
try:
return signature(callable).parameters.get("config") is not None
except ValueError:
@@ -89,7 +104,14 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
def accepts_context(callable: Callable[..., Any]) -> bool:
"""Check if a callable accepts a context argument."""
"""Check if a callable accepts a context argument.
Args:
callable: The callable to check.
Returns:
bool: True if the callable accepts a context argument, False otherwise.
"""
try:
return signature(callable).parameters.get("context") is not None
except ValueError:
@@ -100,10 +122,24 @@ class IsLocalDict(ast.NodeVisitor):
"""Check if a name is a local dict."""
def __init__(self, name: str, keys: Set[str]) -> None:
"""Initialize the visitor.
Args:
name: The name to check.
keys: The keys to populate.
"""
self.name = name
self.keys = keys
def visit_Subscript(self, node: ast.Subscript) -> Any:
"""Visit a subscript node.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if (
isinstance(node.ctx, ast.Load)
and isinstance(node.value, ast.Name)
@@ -115,6 +151,14 @@ class IsLocalDict(ast.NodeVisitor):
self.keys.add(node.slice.value)
def visit_Call(self, node: ast.Call) -> Any:
"""Visit a call node.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
@@ -135,18 +179,42 @@ class IsFunctionArgDict(ast.NodeVisitor):
self.keys: Set[str] = set()
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if not node.args.args:
return
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node.body)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if not node.args.args:
return
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
"""Visit an async function definition.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if not node.args.args:
return
input_arg_name = node.args.args[0].arg
@@ -161,12 +229,28 @@ class NonLocals(ast.NodeVisitor):
self.stores: Set[str] = set()
def visit_Name(self, node: ast.Name) -> Any:
"""Visit a name node.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if isinstance(node.ctx, ast.Load):
self.loads.add(node.id)
elif isinstance(node.ctx, ast.Store):
self.stores.add(node.id)
def visit_Attribute(self, node: ast.Attribute) -> Any:
"""Visit an attribute node.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if isinstance(node.ctx, ast.Load):
parent = node.value
attr_expr = node.attr
@@ -185,16 +269,40 @@ class FunctionNonLocals(ast.NodeVisitor):
self.nonlocals: Set[str] = set()
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
visitor = NonLocals()
visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
"""Visit an async function definition.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
visitor = NonLocals()
visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores)
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
visitor = NonLocals()
visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores)
@@ -209,14 +317,29 @@ class GetLambdaSource(ast.NodeVisitor):
self.count = 0
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function."""
"""Visit a lambda function.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
self.count += 1
if hasattr(ast, "unparse"):
self.source = ast.unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
"""Get the keys of the first argument of a function if it is a dict."""
"""Get the keys of the first argument of a function if it is a dict.
Args:
func: The function to check.
Returns:
Optional[List[str]]: The keys of the first argument if it is a dict,
None otherwise.
"""
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
@@ -231,10 +354,10 @@ def get_lambda_source(func: Callable) -> Optional[str]:
"""Get the source code of a lambda function.
Args:
func: a callable that can be a lambda function
func: a Callable that can be a lambda function.
Returns:
str: the source code of the lambda function
str: the source code of the lambda function.
"""
try:
name = func.__name__ if func.__name__ != "<lambda>" else None
@@ -251,7 +374,14 @@ def get_lambda_source(func: Callable) -> Optional[str]:
def get_function_nonlocals(func: Callable) -> List[Any]:
"""Get the nonlocal variables accessed by a function."""
"""Get the nonlocal variables accessed by a function.
Args:
func: The function to check.
Returns:
List[Any]: The nonlocal variables accessed by the function.
"""
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
@@ -283,11 +413,11 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
"""Indent all lines of text after the first line.
Args:
text: The text to indent
prefix: Used to determine the number of spaces to indent
text: The text to indent.
prefix: Used to determine the number of spaces to indent.
Returns:
str: The indented text
str: The indented text.
"""
n_spaces = len(prefix)
spaces = " " * n_spaces
@@ -341,7 +471,14 @@ Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
def add(addables: Iterable[Addable]) -> Optional[Addable]:
"""Add a sequence of addable objects together."""
"""Add a sequence of addable objects together.
Args:
addables: The addable objects to add.
Returns:
Optional[Addable]: The result of adding the addable objects.
"""
final = None
for chunk in addables:
if final is None:
@@ -352,7 +489,14 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
"""Asynchronously add a sequence of addable objects together."""
"""Asynchronously add a sequence of addable objects together.
Args:
addables: The addable objects to add.
Returns:
Optional[Addable]: The result of adding the addable objects.
"""
final = None
async for chunk in addables:
if final is None:
@@ -363,7 +507,15 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
class ConfigurableField(NamedTuple):
"""Field that can be configured by the user."""
"""Field that can be configured by the user.
Parameters:
id: The unique identifier of the field.
name: The name of the field. Defaults to None.
description: The description of the field. Defaults to None.
annotation: The annotation of the field. Defaults to None.
is_shared: Whether the field is shared. Defaults to False.
"""
id: str
@@ -377,7 +529,16 @@ class ConfigurableField(NamedTuple):
class ConfigurableFieldSingleOption(NamedTuple):
"""Field that can be configured by the user with a default value."""
"""Field that can be configured by the user with a default value.
Parameters:
id: The unique identifier of the field.
options: The options for the field.
default: The default value for the field.
name: The name of the field. Defaults to None.
description: The description of the field. Defaults to None.
is_shared: Whether the field is shared. Defaults to False.
"""
id: str
options: Mapping[str, Any]
@@ -392,7 +553,16 @@ class ConfigurableFieldSingleOption(NamedTuple):
class ConfigurableFieldMultiOption(NamedTuple):
"""Field that can be configured by the user with multiple default values."""
"""Field that can be configured by the user with multiple default values.
Parameters:
id: The unique identifier of the field.
options: The options for the field.
default: The default values for the field.
name: The name of the field. Defaults to None.
description: The description of the field. Defaults to None.
is_shared: Whether the field is shared. Defaults to False.
"""
id: str
options: Mapping[str, Any]
@@ -412,7 +582,17 @@ AnyConfigurableField = Union[
class ConfigurableFieldSpec(NamedTuple):
"""Field that can be configured by the user. It is a specification of a field."""
"""Field that can be configured by the user. It is a specification of a field.
Parameters:
id: The unique identifier of the field.
annotation: The annotation of the field.
name: The name of the field. Defaults to None.
description: The description of the field. Defaults to None.
default: The default value for the field. Defaults to None.
is_shared: Whether the field is shared. Defaults to False.
dependencies: The dependencies of the field. Defaults to None.
"""
id: str
annotation: Any
@@ -427,7 +607,17 @@ class ConfigurableFieldSpec(NamedTuple):
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]:
"""Get the unique config specs from a sequence of config specs."""
"""Get the unique config specs from a sequence of config specs.
Args:
specs: The config specs.
Returns:
List[ConfigurableFieldSpec]: The unique config specs.
Raises:
ValueError: If the runnable sequence contains conflicting config specs.
"""
grouped = groupby(
sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
)
@@ -542,7 +732,15 @@ def _create_model_cached(
def is_async_generator(
func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
"""Check if a function is an async generator."""
"""Check if a function is an async generator.
Args:
func: The function to check.
Returns:
TypeGuard[Callable[..., AsyncIterator]: True if the function is
an async generator, False otherwise.
"""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__")
@@ -553,7 +751,15 @@ def is_async_generator(
def is_async_callable(
func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
"""Check if a function is async."""
"""Check if a function is async.
Args:
func: The function to check.
Returns:
TypeGuard[Callable[..., Awaitable]: True if the function is async,
False otherwise.
"""
return (
asyncio.iscoroutinefunction(func)
or hasattr(func, "__call__")

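The ConfigurableField named tuples above are the user-facing handles for run-time configuration; a sketch of selecting between alternatives with `configurable_alternatives` (prompt texts are illustrative):

```python
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import ConfigurableField

prompt = PromptTemplate.from_template("Tell me a joke about {topic}").configurable_alternatives(
    ConfigurableField(id="prompt"),
    default_key="joke",
    poem=PromptTemplate.from_template("Write a short poem about {topic}"),
)

# Without config the "joke" default is used; the config below switches to the alternative.
print(prompt.with_config(configurable={"prompt": "poem"}).invoke({"topic": "cats"}))
```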
File diff suppressed because it is too large.


@@ -62,7 +62,21 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
"""Start a trace for an LLM run.
Args:
serialized: The serialized model.
messages: The messages to start the chat with.
run_id: The run ID.
tags: The tags for the run. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata for the run. Defaults to None.
name: The name of the run.
**kwargs: Additional arguments.
Returns:
The run.
"""
chat_model_run = self._create_chat_model_run(
serialized=serialized,
messages=messages,
@@ -89,7 +103,21 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
"""Start a trace for an LLM run.
Args:
serialized: The serialized model.
prompts: The prompts to start the LLM with.
run_id: The run ID.
tags: The tags for the run. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata for the run. Defaults to None.
name: The name of the run.
**kwargs: Additional arguments.
Returns:
The run.
"""
llm_run = self._create_llm_run(
serialized=serialized,
prompts=prompts,
@@ -113,7 +141,18 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Run:
"""Run on new LLM token. Only available when streaming is enabled."""
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token: The token.
chunk: The chunk. Defaults to None.
run_id: The run ID.
parent_run_id: The parent run ID. Defaults to None.
**kwargs: Additional arguments.
Returns:
The run.
"""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._llm_run_with_token_event(
@@ -133,6 +172,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Run on retry.
Args:
retry_state: The retry state.
run_id: The run ID.
**kwargs: Additional arguments.
Returns:
The run.
"""
llm_run = self._llm_run_with_retry_event(
retry_state=retry_state,
run_id=run_id,
@@ -140,7 +189,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
return llm_run
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run."""
"""End a trace for an LLM run.
Args:
response: The response.
run_id: The run ID.
**kwargs: Additional arguments.
Returns:
The run.
"""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._complete_llm_run(
@@ -158,7 +216,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for an LLM run."""
"""Handle an error for an LLM run.
Args:
error: The error.
run_id: The run ID.
**kwargs: Additional arguments.
Returns:
The run.
"""
# "chat_model" is only used for the experimental new streaming_events format.
# This change should not affect any existing tracers.
llm_run = self._errored_llm_run(
@@ -182,7 +249,22 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for a chain run."""
"""Start a trace for a chain run.
Args:
serialized: The serialized chain.
inputs: The inputs for the chain.
run_id: The run ID.
tags: The tags for the run. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata for the run. Defaults to None.
run_type: The type of the run. Defaults to None.
name: The name of the run.
**kwargs: Additional arguments.
Returns:
The run.
"""
chain_run = self._create_chain_run(
serialized=serialized,
inputs=inputs,
@@ -206,7 +288,17 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""End a trace for a chain run."""
"""End a trace for a chain run.
Args:
outputs: The outputs for the chain.
run_id: The run ID.
inputs: The inputs for the chain. Defaults to None.
**kwargs: Additional arguments.
Returns:
The run.
"""
chain_run = self._complete_chain_run(
outputs=outputs,
run_id=run_id,
@@ -225,7 +317,17 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for a chain run."""
"""Handle an error for a chain run.
Args:
error: The error.
inputs: The inputs for the chain. Defaults to None.
run_id: The run ID.
**kwargs: Additional arguments.
Returns:
The run.
"""
chain_run = self._errored_chain_run(
error=error,
run_id=run_id,
@@ -249,7 +351,22 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for a tool run."""
"""Start a trace for a tool run.
Args:
serialized: The serialized tool.
input_str: The input string.
run_id: The run ID.
tags: The tags for the run. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata for the run. Defaults to None.
name: The name of the run.
inputs: The inputs for the tool.
**kwargs: Additional arguments.
Returns:
The run.
"""
tool_run = self._create_tool_run(
serialized=serialized,
input_str=input_str,
@@ -266,7 +383,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
return tool_run
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
"""End a trace for a tool run.
Args:
output: The output for the tool.
run_id: The run ID.
**kwargs: Additional arguments.
Returns:
The run.
"""
tool_run = self._complete_tool_run(
output=output,
run_id=run_id,
@@ -283,7 +409,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Handle an error for a tool run."""
"""Handle an error for a tool run.
Args:
error: The error.
run_id: The run ID.
**kwargs: Additional arguments.
Returns:
The run.
"""
tool_run = self._errored_tool_run(
error=error,
run_id=run_id,
@@ -304,7 +439,21 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Run when Retriever starts running."""
"""Run when the Retriever starts running.
Args:
serialized: The serialized retriever.
query: The query.
run_id: The run ID.
parent_run_id: The parent run ID. Defaults to None.
tags: The tags for the run. Defaults to None.
metadata: The metadata for the run. Defaults to None.
name: The name of the run.
**kwargs: Additional arguments.
Returns:
The run.
"""
retrieval_run = self._create_retrieval_run(
serialized=serialized,
query=query,
@@ -326,7 +475,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
"""Run when Retriever errors."""
"""Run when Retriever errors.
Args:
error: The error.
run_id: The run ID.
**kwargs: Additional arguments.
Returns:
The run.
"""
retrieval_run = self._errored_retrieval_run(
error=error,
run_id=run_id,
@@ -339,7 +497,16 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> Run:
"""Run when Retriever ends running."""
"""Run when the Retriever ends running.
Args:
documents: The documents.
run_id: The run ID.
**kwargs: Additional arguments.
Returns:
The run.
"""
retrieval_run = self._complete_retrieval_run(
documents=documents,
run_id=run_id,

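A minimal sketch of hooking into these callbacks with a custom tracer; `BaseTracer` only requires `_persist_run`, which is called once the root run completes:

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run


class PrintTracer(BaseTracer):
    """Toy tracer that prints the root run once the trace is finished."""

    def _persist_run(self, run: Run) -> None:
        print(run.name, run.run_type, run.error or "ok")


chain = RunnableLambda(lambda x: x + 1)
chain.invoke(1, config={"callbacks": [PrintTracer()]})
```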

@@ -68,8 +68,8 @@ def tracing_v2_enabled(
client (LangSmithClient, optional): The client of the langsmith.
Defaults to None.
Returns:
None
Yields:
LangChainTracer: The LangChain tracer.
Example:
>>> with tracing_v2_enabled():
@@ -100,7 +100,7 @@ def tracing_v2_enabled(
def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
"""Collect all run traces in context.
Returns:
Yields:
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
Example:

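The docstring's own example is not shown here; a small sketch of `collect_runs`, which yields a `RunCollectorCallbackHandler` whose `traced_runs` fill up as Runnables execute inside the context:

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.context import collect_runs

chain = RunnableLambda(lambda x: x * 2)

with collect_runs() as cb:
    chain.invoke(3)
    # One root run has been collected by the time invoke() returns.
    print(len(cb.traced_runs), cb.traced_runs[0].name)
```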

@@ -46,7 +46,8 @@ SCHEMA_FORMAT_TYPE = Literal["original", "streaming_events"]
class _TracerCore(ABC):
"""
Abstract base class for tracers
Abstract base class for tracers.
This class provides common methods, and reusable methods for tracers.
"""
@@ -65,17 +66,18 @@ class _TracerCore(ABC):
Args:
_schema_format: Primarily changes how the inputs and outputs are
handled. For internal use only. This API will change.
- 'original' is the format used by all current tracers.
This format is slightly inconsistent with respect to inputs
and outputs.
This format is slightly inconsistent with respect to inputs
and outputs.
- 'streaming_events' is used for supporting streaming events,
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.
- 'original+chat' is a format that is the same as 'original'
except it does NOT raise an attribute error on_chat_model_start
except it does NOT raise an attribute error on_chat_model_start
kwargs: Additional keyword arguments that will be passed to
the super class.
the superclass.
"""
super().__init__(**kwargs)
self._schema_format = _schema_format # For internal use only API will change.
@@ -207,7 +209,7 @@ class _TracerCore(ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Create a llm run"""
"""Create a llm run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
@@ -234,7 +236,7 @@ class _TracerCore(ABC):
**kwargs: Any,
) -> Run:
"""
Append token event to LLM run and return the run
Append token event to LLM run and return the run.
"""
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
event_kwargs: Dict[str, Any] = {"token": token}
@@ -314,7 +316,7 @@ class _TracerCore(ABC):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Create a chain Run"""
"""Create a chain Run."""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})


@@ -104,7 +104,7 @@ class EvaluatorCallbackHandler(BaseTracer):
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
"""Evaluate the run in the project.
Parameters
Args:
----------
run : Run
The run to be evaluated.
@@ -200,7 +200,7 @@ class EvaluatorCallbackHandler(BaseTracer):
def _persist_run(self, run: Run) -> None:
"""Run the evaluator on the run.
Parameters
Args:
----------
run : Run
The run to be evaluated.


@@ -52,7 +52,18 @@ logger = logging.getLogger(__name__)
class RunInfo(TypedDict):
"""Information about a run."""
"""Information about a run.
This is used to keep track of the metadata associated with a run.
Parameters:
name: The name of the run.
tags: The tags associated with the run.
metadata: The metadata associated with the run.
run_type: The type of the run.
inputs: The inputs to the run.
parent_run_id: The ID of the parent run.
"""
name: str
tags: List[str]
@@ -150,7 +161,19 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
async def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
"""Tap the output aiter."""
"""Tap the output aiter.
This method is used to tap the output of a Runnable that produces
an async iterator. It is used to generate stream events for the
output of the Runnable.
Args:
run_id: The ID of the run.
output: The output of the Runnable.
Yields:
T: The output of the Runnable.
"""
sentinel = object()
# atomic check and set
tap = self.is_tapped.setdefault(run_id, sentinel)
@@ -192,7 +215,15 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
yield chunk
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
"""Tap the output aiter."""
"""Tap the output aiter.
Args:
run_id: The ID of the run.
output: The output of the Runnable.
Yields:
T: The output of the Runnable.
"""
sentinel = object()
# atomic check and set
tap = self.is_tapped.setdefault(run_id, sentinel)


@@ -32,7 +32,12 @@ _EXECUTOR: Optional[ThreadPoolExecutor] = None
def log_error_once(method: str, exception: Exception) -> None:
"""Log an error once."""
"""Log an error once.
Args:
method: The method that raised the exception.
exception: The exception that was raised.
"""
global _LOGGED
if (method, type(exception)) in _LOGGED:
return
@@ -82,7 +87,15 @@ class LangChainTracer(BaseTracer):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Initialize the LangChain tracer."""
"""Initialize the LangChain tracer.
Args:
example_id: The example ID.
project_name: The project name. Defaults to the tracer project.
client: The client. Defaults to the global client.
tags: The tags. Defaults to an empty list.
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
@@ -104,7 +117,21 @@ class LangChainTracer(BaseTracer):
name: Optional[str] = None,
**kwargs: Any,
) -> Run:
"""Start a trace for an LLM run."""
"""Start a trace for an LLM run.
Args:
serialized: The serialized model.
messages: The messages.
run_id: The run ID.
tags: The tags. Defaults to None.
parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata. Defaults to None.
name: The name. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
Run: The run.
"""
start_time = datetime.now(timezone.utc)
if metadata:
kwargs.update({"metadata": metadata})
@@ -130,7 +157,15 @@ class LangChainTracer(BaseTracer):
self.latest_run = run_
def get_run_url(self) -> str:
"""Get the LangSmith root run URL"""
"""Get the LangSmith root run URL.
Returns:
str: The LangSmith root run URL.
Raises:
ValueError: If no traced run is found.
ValueError: If the run URL cannot be found.
"""
if not self.latest_run:
raise ValueError("No traced run found.")
# If this is the first run in a project, the project may not yet be created.
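
For orientation, a usage sketch of the tracer documented above (assuming the public `langchain_core.tracers.LangChainTracer` export and a configured LangSmith API key; without credentials the trace upload and `get_run_url()` will fail):

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers import LangChainTracer

tracer = LangChainTracer(project_name="my-project", tags=["demo"])

runnable = RunnableLambda(lambda x: x + 1)
# Passing the tracer as a callback records the run in LangSmith and makes
# get_run_url() available afterwards.
print(runnable.invoke(1, config={"callbacks": [tracer]}))  # 2
print(tracer.get_run_url())
```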

View File

@@ -189,12 +189,15 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
handled.
**For internal use only. This API will change.**
- 'original' is the format used by all current tracers.
This format is slightly inconsistent with respect to inputs
and outputs.
This format is slightly inconsistent with respect to inputs
and outputs.
- 'streaming_events' is used for supporting streaming events,
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.
for internal usage. It will likely change in the future, or
be deprecated entirely in favor of a dedicated async tracer
for streaming events.
Raises:
ValueError: If an invalid schema format is provided (internal use only).
"""
if _schema_format not in {"original", "streaming_events"}:
raise ValueError(
@@ -224,7 +227,15 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
return self.receive_stream.__aiter__()
def send(self, *ops: Dict[str, Any]) -> bool:
"""Send a patch to the stream, return False if the stream is closed."""
"""Send a patch to the stream, return False if the stream is closed.
Args:
*ops: The operations to send to the stream.
Returns:
bool: True if the patch was sent successfully, False if the stream
is closed.
"""
# We will likely want to wrap this in try / except at some point
# to handle exceptions that might arise at run time.
# For now we'll let the exception bubble up, and always return
@@ -235,7 +246,15 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
async def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
"""Tap an output async iterator to stream its values to the log."""
"""Tap an output async iterator to stream its values to the log.
Args:
run_id: The ID of the run.
output: The output async iterator.
Yields:
T: The output value.
"""
async for chunk in output:
# root run is handled in .astream_log()
if run_id != self.root_id:
@@ -254,7 +273,15 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
yield chunk
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
"""Tap an output async iterator to stream its values to the log."""
"""Tap an output async iterator to stream its values to the log.
Args:
run_id: The ID of the run.
output: The output iterator.
Yields:
T: The output value.
"""
for chunk in output:
# root run is handled in .astream_log()
if run_id != self.root_id:
@@ -273,6 +300,14 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
yield chunk
def include_run(self, run: Run) -> bool:
"""Check if a Run should be included in the log.
Args:
run: The Run to check.
Returns:
bool: True if the run should be included, False otherwise.
"""
if run.id == self.root_id:
return False
@@ -454,7 +489,7 @@ def _get_standardized_inputs(
Returns:
Valid inputs are only dicts. By convention, inputs always represent an
invocation using named arguments.
A None means that the input is not yet known!
None means that the input is not yet known!
"""
if schema_format == "original":
raise NotImplementedError(
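
These docstrings back `Runnable.astream_log`, which streams the run log as JSON-patch operations; a small consumption sketch (the exact patch contents depend on the runnable):

```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def main() -> None:
    runnable = RunnableLambda(lambda x: x * 2)
    # Each chunk is a RunLogPatch whose .ops list contains the JSON-patch
    # operations emitted by LogStreamCallbackHandler.
    async for patch in runnable.astream_log(3):
        print(patch.ops)


asyncio.run(main())
```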

View File

@@ -33,11 +33,27 @@ class _SendStream(Generic[T]):
self._done = done
async def send(self, item: T) -> None:
"""Schedule the item to be written to the queue using the original loop."""
"""Schedule the item to be written to the queue using the original loop.
This is a coroutine that can be awaited.
Args:
item: The item to write to the queue.
"""
return self.send_nowait(item)
def send_nowait(self, item: T) -> None:
"""Schedule the item to be written to the queue using the original loop."""
"""Schedule the item to be written to the queue using the original loop.
This is a non-blocking call.
Args:
item: The item to write to the queue.
Raises:
RuntimeError: If the event loop is already closed when trying to write
to the queue.
"""
try:
self._reader_loop.call_soon_threadsafe(self._queue.put_nowait, item)
except RuntimeError:
@@ -45,11 +61,18 @@ class _SendStream(Generic[T]):
raise # Raise the exception if the loop is not closed
async def aclose(self) -> None:
"""Schedule the done object write the queue using the original loop."""
"""Async schedule the done object write the queue using the original loop."""
return self.close()
def close(self) -> None:
"""Schedule the done object write the queue using the original loop."""
"""Schedule the done object write the queue using the original loop.
This is a non-blocking call.
Raises:
RuntimeError: If the event loop is already closed when trying to write
to the queue.
"""
try:
self._reader_loop.call_soon_threadsafe(self._queue.put_nowait, self._done)
except RuntimeError:
@@ -87,7 +110,7 @@ class _MemoryStream(Generic[T]):
This implementation is meant to be used with a single writer and a single reader.
This is an internal implementation to LangChain please do not use it directly.
This is an internal implementation to LangChain. Please do not use it directly.
"""
def __init__(self, loop: AbstractEventLoop) -> None:
@@ -103,11 +126,19 @@ class _MemoryStream(Generic[T]):
self._done = object()
def get_send_stream(self) -> _SendStream[T]:
"""Get a writer for the channel."""
"""Get a writer for the channel.
Returns:
_SendStream: The writer for the channel.
"""
return _SendStream[T](
reader_loop=self._loop, queue=self._queue, done=self._done
)
def get_receive_stream(self) -> _ReceiveStream[T]:
"""Get a reader for the channel."""
"""Get a reader for the channel.
Returns:
_ReceiveStream: The reader for the channel.
"""
return _ReceiveStream[T](queue=self._queue, done=self._done)
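
A minimal, self-contained sketch of the mechanism these docstrings describe (items scheduled onto the reader's event loop with `call_soon_threadsafe`, with a sentinel marking the end of the stream); it stands in for the internal classes rather than using them directly:

```python
import asyncio
import threading


async def main() -> None:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel signalling that the writer is finished

    def writer() -> None:
        # The writer may live on another thread; it schedules puts onto the
        # reader's loop, mirroring _SendStream.send_nowait() and close().
        for item in (1, 2, 3):
            loop.call_soon_threadsafe(queue.put_nowait, item)
        loop.call_soon_threadsafe(queue.put_nowait, done)

    threading.Thread(target=writer).start()

    # The reader consumes until it sees the sentinel, like _ReceiveStream.
    while True:
        item = await queue.get()
        if item is done:
            break
        print(item)  # 1, 2, 3


asyncio.run(main())
```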

View File

@@ -16,7 +16,16 @@ AsyncListener = Union[
class RootListenersTracer(BaseTracer):
"""Tracer that calls listeners on run start, end, and error."""
"""Tracer that calls listeners on run start, end, and error.
Parameters:
log_missing_parent: Whether to log a warning if the parent is missing.
Default is False.
config: The runnable config.
on_start: The listener to call on run start.
on_end: The listener to call on run end.
on_error: The listener to call on run error.
"""
log_missing_parent = False
@@ -28,6 +37,14 @@ class RootListenersTracer(BaseTracer):
on_end: Optional[Listener],
on_error: Optional[Listener],
) -> None:
"""Initialize the tracer.
Args:
config: The runnable config.
on_start: The listener to call on run start.
on_end: The listener to call on run end.
on_error: The listener to call on run error
"""
super().__init__(_schema_format="original+chat")
self.config = config
@@ -63,7 +80,16 @@ class RootListenersTracer(BaseTracer):
class AsyncRootListenersTracer(AsyncBaseTracer):
"""Async Tracer that calls listeners on run start, end, and error."""
"""Async Tracer that calls listeners on run start, end, and error.
Parameters:
log_missing_parent: Whether to log a warning if the parent is missing.
Default is False.
config: The runnable config.
on_start: The listener to call on run start.
on_end: The listener to call on run end.
on_error: The listener to call on run error.
"""
log_missing_parent = False
@@ -75,6 +101,14 @@ class AsyncRootListenersTracer(AsyncBaseTracer):
on_end: Optional[AsyncListener],
on_error: Optional[AsyncListener],
) -> None:
"""Initialize the tracer.
Args:
config: The runnable config.
on_start: The listener to call on run start.
on_end: The listener to call on run end.
on_error: The listener to call on run error
"""
super().__init__(_schema_format="original+chat")
self.config = config
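
These tracers power `Runnable.with_listeners`; a small sketch (the listeners here take the finished `Run` object, matching the single-argument `Listener` form above):

```python
from langchain_core.runnables import RunnableLambda


def on_start(run) -> None:
    print("started:", run.name)


def on_end(run) -> None:
    print("finished with outputs:", run.outputs)


chain = RunnableLambda(lambda x: x + 1).with_listeners(on_start=on_start, on_end=on_end)
chain.invoke(41)  # listeners fire via RootListenersTracer
```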

View File

@@ -8,13 +8,13 @@ from langchain_core.tracers.schemas import Run
class RunCollectorCallbackHandler(BaseTracer):
"""
Tracer that collects all nested runs in a list.
"""Tracer that collects all nested runs in a list.
This tracer is useful for inspection and evaluation purposes.
Parameters
----------
name : str, default="run-collector_callback_handler"
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
"""
@@ -31,6 +31,8 @@ class RunCollectorCallbackHandler(BaseTracer):
----------
example_id : Optional[Union[UUID, str]], default=None
The ID of the example being traced. It can be either a UUID or a string.
**kwargs : Any
Additional keyword arguments
"""
super().__init__(**kwargs)
self.example_id = (

View File

@@ -112,7 +112,15 @@ class ToolRun(BaseRun):
class Run(BaseRunV2):
"""Run schema for the V2 API in the Tracer."""
"""Run schema for the V2 API in the Tracer.
Parameters:
child_runs: The child runs.
tags: The tags. Default is an empty list.
events: The events. Default is an empty list.
trace_id: The trace ID. Default is None.
dotted_order: The dotted order.
"""
child_runs: List[Run] = Field(default_factory=list)
tags: Optional[List[str]] = Field(default_factory=list)

View File

@@ -7,15 +7,14 @@ from langchain_core.utils.input import get_bolded_text, get_colored_text
def try_json_stringify(obj: Any, fallback: str) -> str:
"""
Try to stringify an object to JSON.
"""Try to stringify an object to JSON.
Args:
obj: Object to stringify.
fallback: Fallback string to return if the object cannot be stringified.
Returns:
A JSON string if the object can be stringified, otherwise the fallback string.
"""
try:
return json.dumps(obj, indent=2, ensure_ascii=False)
@@ -45,6 +44,8 @@ class FunctionCallbackHandler(BaseTracer):
"""Tracer that calls a function with a single str parameter."""
name: str = "function_callback_handler"
"""The name of the tracer. This is used to identify the tracer in the logs.
Default is "function_callback_handler"."""
def __init__(self, function: Callable[[str], None], **kwargs: Any) -> None:
super().__init__(**kwargs)
@@ -54,6 +55,14 @@ class FunctionCallbackHandler(BaseTracer):
pass
def get_parents(self, run: Run) -> List[Run]:
"""Get the parents of a run.
Args:
run: The run to get the parents of.
Returns:
A list of parent runs.
"""
parents = []
current_run = run
while current_run.parent_run_id:
@@ -66,6 +75,14 @@ class FunctionCallbackHandler(BaseTracer):
return parents
def get_breadcrumbs(self, run: Run) -> str:
"""Get the breadcrumbs of a run.
Args:
run: The run to get the breadcrumbs of.
Returns:
A string with the breadcrumbs of the run.
"""
parents = self.get_parents(run)[::-1]
string = " > ".join(
f"{parent.run_type}:{parent.name}"

View File

@@ -8,6 +8,17 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
Args:
left: The first dictionary to merge.
others: The other dictionaries to merge.
Returns:
The merged dictionary.
Raises:
TypeError: If the key exists in both dictionaries but has a different type.
TypeError: If the value has an unsupported type.
Example:
If left = {"function_call": {"arguments": None}} and
right = {"function_call": {"arguments": "{\n"}}
@@ -46,7 +57,15 @@ def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]
def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
"""Add many lists, handling None."""
"""Add many lists, handling None.
Args:
left: The first list to merge.
others: The other lists to merge.
Returns:
The merged list.
"""
merged = left.copy() if left is not None else None
for other in others:
if other is None:
@@ -75,6 +94,23 @@ def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]
def merge_obj(left: Any, right: Any) -> Any:
"""Merge two objects.
It handles specific scenarios where a key exists in both
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
Args:
left: The first object to merge.
right: The other object to merge.
Returns:
The merged object.
Raises:
TypeError: If the key exists in both dictionaries but has a different type.
ValueError: If the two objects cannot be merged.
"""
if left is None or right is None:
return left if left is not None else right
elif type(left) is not type(right):

View File

@@ -44,6 +44,18 @@ def py_anext(
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.
Args:
iterator: The async iterator to advance.
default: The value to return if the iterator is exhausted.
If not provided, a StopAsyncIteration exception is raised.
Returns:
The next value from the iterator, or the default value
if the iterator is exhausted.
Raises:
TypeError: If the iterator is not an async iterator.
"""
try:
@@ -71,7 +83,7 @@ def py_anext(
class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
"""Dummy lock that provides the proper interface but no protection."""
async def __aenter__(self) -> None:
pass
@@ -88,7 +100,21 @@ async def tee_peer(
peers: List[Deque[T]],
lock: AsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
"""An individual iterator of a :py:func:`~.tee`"""
"""An individual iterator of a :py:func:`~.tee`.
This function is a generator that yields items from the shared iterator
``iterator``. It buffers items until the least advanced iterator has
yielded them as well. The buffer is shared with all other peers.
Args:
iterator: The shared iterator.
buffer: The buffer for this peer.
peers: The buffers of all peers.
lock: The lock to synchronise access to the shared buffers.
Yields:
The next item from the shared iterator.
"""
try:
while True:
if not buffer:
@@ -204,6 +230,7 @@ class Tee(Generic[T]):
return False
async def aclose(self) -> None:
"""Async close all child iterators."""
for child in self._children:
await child.aclose()
@@ -258,7 +285,7 @@ async def abatch_iterate(
iterable: The async iterable to batch.
Returns:
An async iterator over the batches
An async iterator over the batches.
"""
batch: List[T] = []
async for element in iterable:
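
A usage sketch of the async batching helper (assuming it is importable from `langchain_core.utils.aiter`):

```python
import asyncio
from typing import AsyncIterator

from langchain_core.utils.aiter import abatch_iterate


async def numbers() -> AsyncIterator[int]:
    for i in range(5):
        yield i


async def main() -> None:
    async for batch in abatch_iterate(2, numbers()):
        print(batch)  # [0, 1], [2, 3], [4]


asyncio.run(main())
```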

View File

@@ -1,42 +0,0 @@
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable
def curry(func: Callable[..., Any], **curried_kwargs: Any) -> Callable[..., Any]:
"""Util that wraps a function and partially applies kwargs to it.
Returns a new function whose signature omits the curried variables.
Args:
func: The function to curry.
curried_kwargs: Arguments to apply to the function.
Returns:
A new function with curried arguments applied.
.. versionadded:: 0.2.18
"""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwargs = {**curried_kwargs, **kwargs}
return await func(*args, **new_kwargs)
@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwargs = {**curried_kwargs, **kwargs}
return func(*args, **new_kwargs)
sig = inspect.signature(func)
# Create a new signature without the curried parameters
new_params = [p for name, p in sig.parameters.items() if name not in curried_kwargs]
if asyncio.iscoroutinefunction(func):
async_wrapper = wraps(func)(async_wrapper)
setattr(async_wrapper, "__signature__", sig.replace(parameters=new_params))
return async_wrapper
else:
sync_wrapper = wraps(func)(sync_wrapper)
setattr(sync_wrapper, "__signature__", sig.replace(parameters=new_params))
return sync_wrapper

View File

@@ -36,7 +36,7 @@ def get_from_dict_or_env(
env_key: The environment variable to look up if the key is not
in the dictionary.
default: The default value to return if the key is not in the dictionary
or the environment.
or the environment. Defaults to None.
"""
if isinstance(key, (list, tuple)):
for k in key:
@@ -56,7 +56,22 @@ def get_from_dict_or_env(
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
"""Get a value from a dictionary or an environment variable."""
"""Get a value from a dictionary or an environment variable.
Args:
key: The key to look up in the dictionary.
env_key: The environment variable to look up if the key is not
in the dictionary.
default: The default value to return if the key is not in the dictionary
or the environment. Defaults to None.
Returns:
str: The value of the key.
Raises:
ValueError: If the key is not in the dictionary and no default value is
provided or if the environment variable is not set.
"""
if env_key in os.environ and os.environ[env_key]:
return os.environ[env_key]
elif default is not None:
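
A small sketch of the lookup order described here (dict first, then environment, then default), assuming the helper is re-exported from `langchain_core.utils`:

```python
import os

from langchain_core.utils import get_from_dict_or_env

os.environ["MY_SERVICE_API_KEY"] = "secret-from-env"  # hypothetical variable name

# Not in the dict, so the environment variable is used; if that were unset,
# the default would be returned, and with no default a ValueError is raised.
print(get_from_dict_or_env({}, "api_key", "MY_SERVICE_API_KEY", default="fallback"))
# secret-from-env
```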

View File

@@ -10,7 +10,19 @@ class StrictFormatter(Formatter):
def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str:
"""Check that no arguments are provided."""
"""Check that no arguments are provided.
Args:
format_string: The format string.
args: The arguments.
kwargs: The keyword arguments.
Returns:
The formatted string.
Raises:
ValueError: If any arguments are provided.
"""
if len(args) > 0:
raise ValueError(
"No arguments should be provided, "
@@ -21,6 +33,15 @@ class StrictFormatter(Formatter):
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
"""Check that all input variables are used in the format string.
Args:
format_string: The format string.
input_variables: The input variables.
Raises:
ValueError: If any input variables are not used in the format string.
"""
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
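
A usage sketch of the strict formatter (assuming the import path `langchain_core.utils.formatting`):

```python
from langchain_core.utils.formatting import StrictFormatter

formatter = StrictFormatter()

# Keyword substitution behaves like string.Formatter ...
print(formatter.format("Hello {name}!", name="world"))  # Hello world!

# ... but positional arguments are rejected by vformat().
try:
    formatter.format("Hello {0}!", "world")
except ValueError as err:
    print(err)

# validate_input_variables formats with dummy inputs to catch mismatches early.
formatter.validate_input_variables("Hello {name}!", ["name"])
```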

View File

@@ -55,7 +55,9 @@ class ToolDescription(TypedDict):
"""Representation of a callable function to the OpenAI API."""
type: Literal["function"]
"""The type of the tool."""
function: FunctionDescription
"""The function description."""
def _rm_titles(kv: dict, prev_key: str = "") -> dict:
@@ -85,7 +87,19 @@ def convert_pydantic_to_openai_function(
description: Optional[str] = None,
rm_titles: bool = True,
) -> FunctionDescription:
"""Converts a Pydantic model to a function description for the OpenAI API."""
"""Converts a Pydantic model to a function description for the OpenAI API.
Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
rm_titles: Whether to remove titles from the schema. Defaults to True.
Returns:
The function description.
"""
schema = dereference_refs(model.schema())
schema.pop("definitions", None)
title = schema.pop("title", "")
@@ -108,7 +122,18 @@ def convert_pydantic_to_openai_tool(
name: Optional[str] = None,
description: Optional[str] = None,
) -> ToolDescription:
"""Converts a Pydantic model to a function description for the OpenAI API."""
"""Converts a Pydantic model to a function description for the OpenAI API.
Args:
model: The Pydantic model to convert.
name: The name of the function. If not provided, the title of the schema will be
used.
description: The description of the function. If not provided, the description
of the schema will be used.
Returns:
The tool description.
"""
function = convert_pydantic_to_openai_function(
model, name=name, description=description
)
@@ -133,12 +158,22 @@ def convert_python_function_to_openai_function(
Assumes the Python function has type hints and a docstring with a description. If
the docstring has Google Python style argument descriptions, these will be
included as well.
Args:
function: The Python function to convert.
Returns:
The OpenAI function description.
"""
from langchain_core import tools
func_name = _get_python_function_name(function)
model = tools.create_schema_from_function(
func_name, function, filter_args=(), parse_docstring=True
func_name,
function,
filter_args=(),
parse_docstring=True,
error_on_invalid_docstring=False,
)
return convert_pydantic_to_openai_function(
model,
@@ -153,7 +188,14 @@ def convert_python_function_to_openai_function(
removal="0.3.0",
)
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tool into the OpenAI function API."""
"""Format tool into the OpenAI function API.
Args:
tool: The tool to format.
Returns:
The function description.
"""
if tool.args_schema:
return convert_pydantic_to_openai_function(
tool.args_schema, name=tool.name, description=tool.description
@@ -183,7 +225,14 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
removal="0.3.0",
)
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
"""Format tool into the OpenAI function API."""
"""Format tool into the OpenAI function API.
Args:
tool: The tool to format.
Returns:
The tool description.
"""
function = format_tool_to_openai_function(tool)
return {"type": "function", "function": function}
@@ -202,6 +251,9 @@ def convert_to_openai_function(
Returns:
A dict version of the passed in function which is compatible with the
OpenAI function-calling API.
Raises:
ValueError: If the function is not in a supported format.
"""
from langchain_core.tools import BaseTool
@@ -280,7 +332,7 @@ def tool_example_to_messages(
BaseModels
tool_outputs: Optional[List[str]], a list of tool call outputs.
Does not need to be provided. If not provided, a placeholder value
will be inserted.
will be inserted. Defaults to None.
Returns:
A list of messages
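
For reference, a minimal example of the conversion these docstrings describe, using the public `convert_to_openai_function` entry point from `langchain_core.utils.function_calling`:

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_function


class Multiply(BaseModel):
    """Multiply two integers."""

    a: int = Field(..., description="The first factor")
    b: int = Field(..., description="The second factor")


function = convert_to_openai_function(Multiply)
print(function["name"])        # Multiply
print(function["parameters"])  # JSON Schema describing the two integer fields
```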

View File

@@ -34,11 +34,11 @@ DEFAULT_LINK_REGEX = (
def find_all_links(
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]:
"""Extract all links from a raw html string.
"""Extract all links from a raw HTML string.
Args:
raw_html: original html.
pattern: Regex to use for extracting links from raw html.
raw_html: original HTML.
pattern: Regex to use for extracting links from raw HTML.
Returns:
List[str]: all links
@@ -57,20 +57,20 @@ def extract_sub_links(
exclude_prefixes: Sequence[str] = (),
continue_on_failure: bool = False,
) -> List[str]:
"""Extract all links from a raw html string and convert into absolute paths.
"""Extract all links from a raw HTML string and convert into absolute paths.
Args:
raw_html: original html.
url: the url of the html.
base_url: the base url to check for outside links against.
pattern: Regex to use for extracting links from raw html.
raw_html: original HTML.
url: the url of the HTML.
base_url: the base URL to check for outside links against.
pattern: Regex to use for extracting links from raw HTML.
prevent_outside: If True, ignore external links which are not children
of the base url.
of the base URL.
exclude_prefixes: Exclude any URLs that start with one of these prefixes.
continue_on_failure: If True, continue if parsing a specific link raises an
exception. Otherwise, raise the exception.
Returns:
List[str]: sub links
List[str]: sub links.
"""
base_url_to_use = base_url if base_url is not None else url
parsed_base_url = urlparse(base_url_to_use)
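
A usage sketch for the link helpers (assuming the import path `langchain_core.utils.html`; `find_all_links` de-duplicates, so its output order is not guaranteed):

```python
from langchain_core.utils.html import extract_sub_links, find_all_links

raw_html = '<a href="/docs">Docs</a> <a href="https://other.example.com/">Other</a>'

print(sorted(find_all_links(raw_html)))
# ['/docs', 'https://other.example.com/']

# Relative links are resolved against the page URL; the external link is
# dropped because it is not under the given base URL.
print(
    extract_sub_links(
        raw_html,
        "https://example.com/index.html",
        base_url="https://example.com",
    )
)
# ['https://example.com/docs']
```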

View File

@@ -3,12 +3,27 @@ import mimetypes
def encode_image(image_path: str) -> str:
"""Get base64 string from image URI."""
"""Get base64 string from image URI.
Args:
image_path: The path to the image.
Returns:
The base64 string of the image.
"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def image_to_data_url(image_path: str) -> str:
"""Get data URL from image URI.
Args:
image_path: The path to the image.
Returns:
The data URL of the image.
"""
encoding = encode_image(image_path)
mime_type = mimetypes.guess_type(image_path)[0]
return f"data:{mime_type};base64,{encoding}"

View File

@@ -14,7 +14,15 @@ _TEXT_COLOR_MAPPING = {
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
"""Get mapping for items to a support color.
Args:
items: The items to map to colors.
excluded_colors: The colors to exclude.
Returns:
The mapping of items to colors.
"""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
@@ -23,20 +31,45 @@ def get_color_mapping(
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
"""Get colored text.
Args:
text: The text to color.
color: The color to use.
Returns:
The colored text.
"""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def get_bolded_text(text: str) -> str:
"""Get bolded text."""
"""Get bolded text.
Args:
text: The text to bold.
Returns:
The bolded text.
"""
return f"\033[1m{text}\033[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
"""Print text with highlighting and no end characters.
If a color is provided, the text will be printed in that color.
If a file is provided, the text will be written to that file.
Args:
text: The text to print.
color: The color to use. Defaults to None.
end: The end character to use. Defaults to "".
file: The file to write to. Defaults to None.
"""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
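
A small sketch of these helpers (assuming the import path `langchain_core.utils.input`; the exact colors assigned by `get_color_mapping` depend on the internal color list):

```python
from langchain_core.utils.input import get_color_mapping, get_colored_text, print_text

mapping = get_color_mapping(["step-1", "step-2"], excluded_colors=["red"])
print(mapping)  # e.g. {'step-1': 'blue', 'step-2': 'yellow'}

# get_colored_text wraps the string in ANSI escape codes; print_text writes it
# without a trailing newline unless `end` is given.
print_text(get_colored_text("hello", mapping["step-1"]), end="\n")
```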

View File

@@ -22,7 +22,7 @@ T = TypeVar("T")
class NoLock:
"""Dummy lock that provides the proper interface but no protection"""
"""Dummy lock that provides the proper interface but no protection."""
def __enter__(self) -> None:
pass
@@ -39,7 +39,21 @@ def tee_peer(
peers: List[Deque[T]],
lock: ContextManager[Any],
) -> Generator[T, None, None]:
"""An individual iterator of a :py:func:`~.tee`"""
"""An individual iterator of a :py:func:`~.tee`.
This function is a generator that yields items from the shared iterator
``iterator``. It buffers items until the least advanced iterator has
yielded them as well. The buffer is shared with all other peers.
Args:
iterator: The shared iterator.
buffer: The buffer for this peer.
peers: The buffers of all peers.
lock: The lock to synchronise access to the shared buffers.
Yields:
The next item from the shared iterator.
"""
try:
while True:
if not buffer:
@@ -118,6 +132,14 @@ class Tee(Generic[T]):
*,
lock: Optional[ContextManager[Any]] = None,
):
"""Create a new ``tee``.
Args:
iterable: The iterable to split.
n: The number of iterators to create. Defaults to 2.
lock: The lock to synchronise access to the shared buffers.
Defaults to None.
"""
self._iterator = iter(iterable)
self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
self._children = tuple(
@@ -170,8 +192,8 @@ def batch_iterate(size: Optional[int], iterable: Iterable[T]) -> Iterator[List[T
size: The size of the batch. If None, returns a single batch.
iterable: The iterable to batch.
Returns:
An iterator over the batches.
Yields:
The batches of the iterable.
"""
it = iter(iterable)
while True:
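
A usage sketch of the sync batching helper (assuming the import path `langchain_core.utils.iter`):

```python
from langchain_core.utils.iter import batch_iterate

# Yields successive batches of at most `size` items; a size of None would
# yield everything as a single batch.
for batch in batch_iterate(2, ["a", "b", "c", "d", "e"]):
    print(batch)
# ['a', 'b']
# ['c', 'd']
# ['e']
```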

View File

@@ -124,8 +124,7 @@ _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
def parse_json_markdown(
json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict:
"""
Parse a JSON string from a Markdown string.
"""Parse a JSON string from a Markdown string.
Args:
json_string: The Markdown string.
@@ -175,6 +174,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
Returns:
The parsed JSON object as a Python dictionary.
Raises:
OutputParserException: If the JSON string is invalid or does not contain
the expected keys.
"""
try:
json_obj = parse_json_markdown(text)
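
A usage sketch of the Markdown JSON parser (assuming the import path `langchain_core.utils.json`; the fenced block is assembled programmatically here only to avoid nesting code fences on this page):

```python
from langchain_core.utils.json import parse_json_markdown

fence = "```"
text = f'{fence}json\n{{"answer": 42, "reason": "because"}}\n{fence}'

print(parse_json_markdown(text))
# {'answer': 42, 'reason': 'because'}
```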

View File

@@ -90,7 +90,16 @@ def dereference_refs(
full_schema: Optional[dict] = None,
skip_keys: Optional[Sequence[str]] = None,
) -> dict:
"""Try to substitute $refs in JSON Schema."""
"""Try to substitute $refs in JSON Schema.
Args:
schema_obj: The schema object to dereference.
full_schema: The full schema object. Defaults to None.
skip_keys: The keys to skip. Defaults to None.
Returns:
The dereferenced schema object.
"""
full_schema = full_schema or schema_obj
skip_keys = (
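
A small sketch of `$ref` substitution (assuming the import path `langchain_core.utils.json_schema`):

```python
from langchain_core.utils.json_schema import dereference_refs

schema = {
    "type": "object",
    "properties": {"person": {"$ref": "#/definitions/Person"}},
    "definitions": {
        "Person": {"type": "object", "properties": {"name": {"type": "string"}}}
    },
}

print(dereference_refs(schema)["properties"]["person"])
# {'type': 'object', 'properties': {'name': {'type': 'string'}}}
```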

View File

@@ -42,7 +42,15 @@ class ChevronError(SyntaxError):
def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
"""Parse a literal from the template."""
"""Parse a literal from the template.
Args:
template: The template to parse.
l_del: The left delimiter.
Returns:
Tuple[str, str]: The literal and the template.
"""
global _CURRENT_LINE
@@ -59,7 +67,16 @@ def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
"""Do a preliminary check to see if a tag could be a standalone."""
"""Do a preliminary check to see if a tag could be a standalone.
Args:
template: The template. (Not used.)
literal: The literal.
is_standalone: Whether the tag is standalone.
Returns:
bool: Whether the tag could be a standalone.
"""
# If there is a newline, or the previous tag was a standalone
if literal.find("\n") != -1 or is_standalone:
@@ -77,7 +94,16 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
"""Do a final check to see if a tag could be a standalone."""
"""Do a final check to see if a tag could be a standalone.
Args:
template: The template.
tag_type: The type of the tag.
is_standalone: Whether the tag is standalone.
Returns:
bool: Whether the tag could be a standalone.
"""
# Check right side if we might be a standalone
if is_standalone and tag_type not in ["variable", "no escape"]:
@@ -95,7 +121,20 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
"""Parse a tag from a template."""
"""Parse a tag from a template.
Args:
template: The template.
l_del: The left delimiter.
r_del: The right delimiter.
Returns:
Tuple[Tuple[str, str], str]: The tag and the template.
Raises:
ChevronError: If the tag is unclosed.
ChevronError: If the set delimiter tag is unclosed.
"""
global _CURRENT_LINE
global _LAST_TAG_LINE
@@ -404,36 +443,36 @@ def render(
Arguments:
template -- A file-like object or a string containing the template
template -- A file-like object or a string containing the template.
data -- A python dictionary with your data scope
data -- A python dictionary with your data scope.
partials_path -- The path to where your partials are stored
partials_path -- The path to where your partials are stored.
If set to None, then partials won't be loaded from the file system
(defaults to '.')
(defaults to '.').
partials_ext -- The extension that you want the parser to look for
(defaults to 'mustache')
(defaults to 'mustache').
partials_dict -- A python dictionary which will be search for partials
before the filesystem is. {'include': 'foo'} is the same
as a file called include.mustache
(defaults to {})
(defaults to {}).
padding -- This is for padding partials, and shouldn't be used
(but can be if you really want to)
(but can be if you really want to).
def_ldel -- The default left delimiter
("{{" by default, as in spec compliant mustache)
("{{" by default, as in spec compliant mustache).
def_rdel -- The default right delimiter
("}}" by default, as in spec compliant mustache)
("}}" by default, as in spec compliant mustache).
scopes -- The list of scopes that get_key will look through
scopes -- The list of scopes that get_key will look through.
warn -- Log a warning when a template substitution isn't found in the data
keep -- Keep unreplaced tags when a substitution isn't found in the data
keep -- Keep unreplaced tags when a substitution isn't found in the data.
Returns:
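
A usage sketch of the renderer (assuming it is importable as `langchain_core.utils.mustache.render`):

```python
from langchain_core.utils.mustache import render

template = "Hello, {{name}}! You have {{count}} new messages."
print(render(template, {"name": "world", "count": 3}))
# Hello, world! You have 3 new messages.
```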

View File

@@ -21,12 +21,27 @@ PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
# How to type hint this?
def pre_init(func: Callable) -> Any:
"""Decorator to run a function before model initialization."""
"""Decorator to run a function before model initialization.
Args:
func (Callable): The function to run before model initialization.
Returns:
Any: The decorated function.
"""
@root_validator(pre=True)
@wraps(func)
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
"""Decorator to run a function before model initialization."""
"""Decorator to run a function before model initialization.
Args:
cls (Type[BaseModel]): The model class.
values (Dict[str, Any]): The values to initialize the model with.
Returns:
Dict[str, Any]: The values to initialize the model with.
"""
# Insert default values
fields = cls.__fields__
for name, field_info in fields.items():
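
A small sketch of the decorator in use: a values-normalising hook that runs before validation, built on a pre `root_validator` as described above (import path assumed to be `langchain_core.utils.pydantic`):

```python
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import pre_init


class Settings(BaseModel):
    host: str
    port: int = 8080

    @pre_init
    def normalize(cls, values):
        # Runs before model initialization; defaults have already been inserted.
        values["host"] = values["host"].strip().lower()
        return values


print(Settings(host="  Example.COM  "))
# host='example.com' port=8080
```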

View File

@@ -36,5 +36,12 @@ def stringify_dict(data: dict) -> str:
def comma_list(items: List[Any]) -> str:
"""Convert a list to a comma-separated string."""
"""Convert a list to a comma-separated string.
Args:
items: The list to convert.
Returns:
str: The comma-separated string.
"""
return ", ".join(str(item) for item in items)

View File

@@ -15,7 +15,18 @@ from langchain_core.pydantic_v1 import SecretStr
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."""
"""Validate specified keyword args are mutually exclusive."
Args:
*arg_groups (Tuple[str, ...]): Groups of mutually exclusive keyword args.
Returns:
Callable: Decorator that validates the specified keyword args
are mutually exclusive
Raises:
ValueError: If more than one arg in a group is defined.
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
@@ -41,7 +52,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
def raise_for_status_with_text(response: Response) -> None:
"""Raise an error with the response text."""
"""Raise an error with the response text.
Args:
response (Response): The response to check for errors.
Raises:
ValueError: If the response has an error status code.
"""
try:
response.raise_for_status()
except HTTPError as e:
@@ -52,6 +70,12 @@ def raise_for_status_with_text(response: Response) -> None:
def mock_now(dt_value): # type: ignore
"""Context manager for mocking out datetime.now() in unit tests.
Args:
dt_value: The datetime value to use for datetime.now().
Yields:
datetime.datetime: The mocked datetime class.
Example:
with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
@@ -86,7 +110,21 @@ def guard_import(
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
"""Dynamically import a module and raise an exception if the module is not
installed."""
installed.
Args:
module_name (str): The name of the module to import.
pip_name (str, optional): The name of the module to install with pip.
Defaults to None.
package (str, optional): The package to import the module from.
Defaults to None.
Returns:
Any: The imported module.
Raises:
ImportError: If the module is not installed.
"""
try:
module = importlib.import_module(module_name, package)
except (ImportError, ModuleNotFoundError):
@@ -105,7 +143,22 @@ def check_package_version(
gt_version: Optional[str] = None,
gte_version: Optional[str] = None,
) -> None:
"""Check the version of a package."""
"""Check the version of a package.
Args:
package (str): The name of the package.
lt_version (str, optional): The version must be less than this.
Defaults to None.
lte_version (str, optional): The version must be less than or equal to this.
Defaults to None.
gt_version (str, optional): The version must be greater than this.
Defaults to None.
gte_version (str, optional): The version must be greater than or equal to this.
Defaults to None.
Raises:
ValueError: If the package version does not meet the requirements.
"""
imported_version = parse(version(package))
if lt_version is not None and imported_version >= parse(lt_version):
raise ValueError(
@@ -133,7 +186,11 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
"""Get field names, including aliases, for a pydantic class.
Args:
pydantic_cls: Pydantic class."""
pydantic_cls: Pydantic class.
Returns:
Set[str]: Field names.
"""
all_required_field_names = set()
for field in pydantic_cls.__fields__.values():
all_required_field_names.add(field.name)
@@ -153,6 +210,13 @@ def build_extra_kwargs(
extra_kwargs: Extra kwargs passed in by user.
values: Values passed in by user.
all_required_field_names: All required field names for the pydantic class.
Returns:
Dict[str, Any]: Extra kwargs.
Raises:
ValueError: If a field is specified in both values and extra_kwargs.
ValueError: If a field is specified in model_kwargs.
"""
for field_name in list(values):
if field_name in extra_kwargs:
@@ -176,7 +240,14 @@ def build_extra_kwargs(
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
"""Convert a string to a SecretStr if needed.
Args:
value (Union[SecretStr, str]): The value to convert.
Returns:
SecretStr: The SecretStr value.
"""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
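
A small sketch of two of the helpers above (assuming they are re-exported from `langchain_core.utils`):

```python
from langchain_core.utils import convert_to_secret_str, guard_import

# guard_import raises an ImportError with a pip install hint if the module
# is missing; here the stdlib json module is guaranteed to be available.
json_module = guard_import("json")
print(json_module.dumps({"ok": True}))

secret = convert_to_secret_str("my-api-key")
print(secret)                     # **********
print(secret.get_secret_value())  # my-api-key
```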

View File

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-core"
version = "0.2.14"
version = "0.2.19"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -38,11 +38,6 @@ python = ">=3.12.4"
[tool.ruff.lint]
select = [ "E", "F", "I", "T201",]
[tool.ruff.lint.per-file-ignores]
"tests/unit_tests/prompts/test_chat.py" = ["E501"]
"tests/unit_tests/runnables/test_runnable.py" = ["E501"]
"tests/unit_tests/runnables/test_graph.py" = ["E501"]
[tool.coverage.run]
omit = [ "tests/*",]
@@ -66,6 +61,11 @@ optional = true
[tool.poetry.group.test_integration]
optional = true
[tool.ruff.lint.per-file-ignores]
"tests/unit_tests/prompts/test_chat.py" = [ "E501",]
"tests/unit_tests/runnables/test_runnable.py" = [ "E501",]
"tests/unit_tests/runnables/test_graph.py" = [ "E501",]
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
@@ -90,12 +90,6 @@ pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
[tool.poetry.group.test.dependencies.langchain-standard-tests]
path = "../standard-tests"
develop = true
[[tool.poetry.group.test.dependencies.numpy]]
version = "^1.24.0"
python = "<3.12"
@@ -109,3 +103,7 @@ python = ">=3.12"
[tool.poetry.group.typing.dependencies.langchain-text-splitters]
path = "../text-splitters"
develop = true
[tool.poetry.group.test.dependencies.langchain-standard-tests]
path = "../standard-tests"
develop = true

View File

@@ -2,25 +2,21 @@ from datetime import datetime
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Type,
)
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.indexing import InMemoryRecordManager, aindex, index
from langchain_core.indexing.api import _abatch, _HashedDocument
from langchain_core.vectorstores import VST, VectorStore
from langchain_core.vectorstores import InMemoryVectorStore, VectorStore
class ToyLoader(BaseLoader):
@@ -42,101 +38,6 @@ class ToyLoader(BaseLoader):
yield document
class InMemoryVectorStore(VectorStore):
"""In-memory implementation of VectorStore using a dictionary."""
def __init__(self, permit_upserts: bool = False) -> None:
"""Vector store interface for testing things in memory."""
self.store: Dict[str, Document] = {}
self.permit_upserts = permit_upserts
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
"""Delete the given documents from the store using their IDs."""
if ids:
for _id in ids:
self.store.pop(_id, None)
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
"""Delete the given documents from the store using their IDs."""
if ids:
for _id in ids:
self.store.pop(_id, None)
def add_documents( # type: ignore
self,
documents: Sequence[Document],
*,
ids: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Add the given documents to the store (insert behavior)."""
if ids and len(ids) != len(documents):
raise ValueError(
f"Expected {len(ids)} ids, got {len(documents)} documents."
)
if not ids:
raise NotImplementedError("This is not implemented yet.")
for _id, document in zip(ids, documents):
if _id in self.store and not self.permit_upserts:
raise ValueError(
f"Document with uid {_id} already exists in the store."
)
self.store[_id] = document
return list(ids)
async def aadd_documents(
self,
documents: Sequence[Document],
*,
ids: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> List[str]:
if ids and len(ids) != len(documents):
raise ValueError(
f"Expected {len(ids)} ids, got {len(documents)} documents."
)
if not ids:
raise NotImplementedError("This is not implemented yet.")
for _id, document in zip(ids, documents):
if _id in self.store and not self.permit_upserts:
raise ValueError(
f"Document with uid {_id} already exists in the store."
)
self.store[_id] = document
return list(ids)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[Dict[Any, Any]]] = None,
**kwargs: Any,
) -> List[str]:
"""Add the given texts to the store (insert behavior)."""
raise NotImplementedError()
@classmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[Dict[Any, Any]]] = None,
**kwargs: Any,
) -> VST:
"""Create a vector store from a list of texts."""
raise NotImplementedError()
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Find the most similar documents to the given query."""
raise NotImplementedError()
@pytest.fixture
def record_manager() -> InMemoryRecordManager:
"""Timestamped set fixture."""
@@ -156,13 +57,15 @@ async def arecord_manager() -> InMemoryRecordManager:
@pytest.fixture
def vector_store() -> InMemoryVectorStore:
"""Vector store fixture."""
return InMemoryVectorStore()
embeddings = DeterministicFakeEmbedding(size=5)
return InMemoryVectorStore(embeddings)
@pytest.fixture
def upserting_vector_store() -> InMemoryVectorStore:
"""Vector store fixture."""
return InMemoryVectorStore(permit_upserts=True)
embeddings = DeterministicFakeEmbedding(size=5)
return InMemoryVectorStore(embeddings)
def test_indexing_same_content(
@@ -286,7 +189,7 @@ def test_index_simple_delete_full(
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"mutated document 1", "This is another document."}
@@ -368,7 +271,7 @@ async def test_aindex_simple_delete_full(
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"mutated document 1", "This is another document."}
@@ -659,7 +562,7 @@ def test_incremental_delete(
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"This is another document.", "This is a test document."}
@@ -718,7 +621,7 @@ def test_incremental_delete(
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {
@@ -786,7 +689,7 @@ def test_incremental_indexing_with_batch_size(
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"1", "2", "3", "4"}
@@ -836,7 +739,7 @@ def test_incremental_delete_with_batch_size(
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"1", "2", "3", "4"}
@@ -981,7 +884,7 @@ async def test_aincremental_delete(
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"This is another document.", "This is a test document."}
@@ -1040,7 +943,7 @@ async def test_aincremental_delete(
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
vector_store.get_by_ids([uid])[0].page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {
@@ -1232,8 +1135,10 @@ def test_deduplication_v2(
# using in memory implementation here
assert isinstance(vector_store, InMemoryVectorStore)
ids = list(vector_store.store.keys())
contents = sorted(
[document.page_content for document in vector_store.store.values()]
[document.page_content for document in vector_store.get_by_ids(ids)]
)
assert contents == ["1", "2", "3"]
@@ -1370,11 +1275,19 @@ def test_indexing_custom_batch_size(
ids = [_HashedDocument.from_document(doc).uid for doc in docs]
batch_size = 1
with patch.object(vector_store, "add_documents") as mock_add_documents:
original = vector_store.add_documents
try:
mock_add_documents = MagicMock()
vector_store.add_documents = mock_add_documents # type: ignore
index(docs, record_manager, vector_store, batch_size=batch_size)
args, kwargs = mock_add_documents.call_args
assert args == (docs,)
assert kwargs == {"ids": ids, "batch_size": batch_size}
finally:
vector_store.add_documents = original # type: ignore
async def test_aindexing_custom_batch_size(
@@ -1390,8 +1303,9 @@ async def test_aindexing_custom_batch_size(
ids = [_HashedDocument.from_document(doc).uid for doc in docs]
batch_size = 1
with patch.object(vector_store, "aadd_documents") as mock_add_documents:
await aindex(docs, arecord_manager, vector_store, batch_size=batch_size)
args, kwargs = mock_add_documents.call_args
assert args == (docs,)
assert kwargs == {"ids": ids, "batch_size": batch_size}
mock_add_documents = AsyncMock()
vector_store.aadd_documents = mock_add_documents # type: ignore
await aindex(docs, arecord_manager, vector_store, batch_size=batch_size)
args, kwargs = mock_add_documents.call_args
assert args == (docs,)
assert kwargs == {"ids": ids, "batch_size": batch_size}

View File

@@ -279,7 +279,7 @@ class CustomChat(GenericFakeChatModel):
async def test_can_swap_caches() -> None:
"""Test that we can use a different cache object.
This test verifies that when we fetch teh llm_string representation
This test verifies that when we fetch the llm_string representation
of the chat model, we can swap the cache object and still get the same
result.
"""

View File

@@ -1,19 +1,16 @@
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
def test_serdes_message() -> None:
msg = AIMessage(
content=[{"text": "blah", "type": "text"}],
tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
tool_calls=[create_tool_call(name="foo", args={"bar": 1}, id="baz")],
invalid_tool_calls=[
InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
create_invalid_tool_call(name="foobad", args="blah", id="booz", error="bad")
],
)
expected = {
@@ -23,9 +20,17 @@ def test_serdes_message() -> None:
"kwargs": {
"type": "ai",
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"tool_calls": [
{"name": "foo", "args": {"bar": 1}, "id": "baz", "type": "tool_call"}
],
"invalid_tool_calls": [
{"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
{
"name": "foobad",
"args": "blah",
"id": "booz",
"error": "bad",
"type": "invalid_tool_call",
}
],
},
}
@@ -38,8 +43,13 @@ def test_serdes_message_chunk() -> None:
chunk = AIMessageChunk(
content=[{"text": "blah", "type": "text"}],
tool_call_chunks=[
ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
create_tool_call_chunk(name="foo", args='{"bar": 1}', id="baz", index=0),
create_tool_call_chunk(
name="foobad",
args="blah",
id="booz",
index=1,
),
],
)
expected = {
@@ -49,18 +59,33 @@ def test_serdes_message_chunk() -> None:
"kwargs": {
"type": "AIMessageChunk",
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"tool_calls": [
{"name": "foo", "args": {"bar": 1}, "id": "baz", "type": "tool_call"}
],
"invalid_tool_calls": [
{
"name": "foobad",
"args": "blah",
"id": "booz",
"error": None,
"type": "invalid_tool_call",
}
],
"tool_call_chunks": [
{"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
{"name": "foobad", "args": "blah", "id": "booz", "index": 1},
{
"name": "foo",
"args": '{"bar": 1}',
"id": "baz",
"index": 0,
"type": "tool_call_chunk",
},
{
"name": "foobad",
"args": "blah",
"id": "booz",
"index": 1,
"type": "tool_call_chunk",
},
],
},
}

Some files were not shown because too many files have changed in this diff.