Compare commits

..

20 Commits

Author SHA1 Message Date
Sydney Runkle
89d10ca1a9 new typing 2025-10-11 07:34:42 -04:00
Sydney Runkle
760fc3bc12 chore(langchain_v1): use args for HITL (#33442) 2025-10-11 07:12:46 -04:00
Eugene Yurtsev
e3fc7d8aa6 chore(langchain_v1): bump release version (#33440)
bump v1 for release
2025-10-10 21:51:00 -04:00
Eugene Yurtsev
2b3b209e40 chore(langchain_v1): improve error message (#33433)
Make error messages actionable for sync / async decorators
2025-10-10 17:18:20 -04:00
ccurme
78903ac285 fix(openai): conditionally skip test (#33431) 2025-10-10 21:04:18 +00:00
ccurme
f361acc11c chore(anthropic): speed up integration tests (#33430) 2025-10-10 20:57:44 +00:00
Eugene Yurtsev
ed185c0026 chore(langchain_v1): remove langchain_text_splitters from test group (#33425)
Remove langchain_text_splitters from test group in langchain_v1
2025-10-10 16:56:14 -04:00
Eugene Yurtsev
6dc34beb71 chore(langchain_v1): stricter handling of sync vs. async for wrap_model_call and wrap_tool_call (#33429)
Wrap model call and wrap tool call
2025-10-10 16:54:42 -04:00
Eugene Yurtsev
c2205f88e6 chore(langchain_v1): further namespace clean up (#33428)
Reduce exposed namespace for now
2025-10-10 20:48:24 +00:00
ccurme
abdbe185c5 release(anthropic): 1.0.0a4 (#33427) 2025-10-10 16:39:58 -04:00
ccurme
c1b816cb7e fix(fireworks): parse standard blocks in input (#33426) 2025-10-10 16:18:37 -04:00
Eugene Yurtsev
0559558715 feat(langchain_v1): add async implementation for wrap_tool_call (#33420)
Add async implementation. No automatic delegation to sync at the moment.
2025-10-10 15:07:19 -04:00
Eugene Yurtsev
75965474fc chore(langchain_v1): tool error exceptions (#33424)
Tool error exceptions
2025-10-10 15:06:40 -04:00
Mason Daugherty
5dc014fdf4 chore(core): delete get_relevant_documents (#33378)
Co-authored-by: Chester Curme <chester.curme@gmail.com>
2025-10-10 14:51:54 -04:00
Mason Daugherty
291a9fcea1 style: llm -> model (#33423) 2025-10-10 13:19:13 -04:00
Christophe Bornet
dd994b9d7f chore(langchain): remove arg types from docstrings (#33413)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-10-10 11:51:00 -04:00
Christophe Bornet
83901b30e3 chore(text-splitters): remove arg types from docstrings (#33406)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-10-10 11:37:53 -04:00
Mason Daugherty
bcfa21a6e7 chore(infra): remove Poetry setup and dependencies (#33418)
AWS now uses UV
2025-10-10 11:29:52 -04:00
ccurme
af1da28459 feat(langchain_v1): expand message exports (#33419) 2025-10-10 15:14:51 +00:00
Mason Daugherty
ed2ee4e8cc style: fix tables, capitalization (#33417) 2025-10-10 11:09:59 -04:00
156 changed files with 2303 additions and 2316 deletions

View File

@@ -23,10 +23,8 @@ permissions:
contents: read
env:
POETRY_VERSION: "1.8.4"
UV_FROZEN: "true"
DEFAULT_LIBS: '["libs/partners/openai", "libs/partners/anthropic", "libs/partners/fireworks", "libs/partners/groq", "libs/partners/mistralai", "libs/partners/xai", "libs/partners/google-vertexai", "libs/partners/google-genai", "libs/partners/aws"]'
POETRY_LIBS: ("libs/partners/aws")
jobs:
# Generate dynamic test matrix based on input parameters or defaults
@@ -60,7 +58,6 @@ jobs:
echo $matrix
echo "matrix=$matrix" >> $GITHUB_OUTPUT
# Run integration tests against partner libraries with live API credentials
# Tests are run with Poetry or UV depending on the library's setup
build:
if: github.repository_owner == 'langchain-ai' || github.event_name != 'schedule'
name: "🐍 Python ${{ matrix.python-version }}: ${{ matrix.working-directory }}"
@@ -95,17 +92,7 @@ jobs:
mv langchain-google/libs/vertexai langchain/libs/partners/google-vertexai
mv langchain-aws/libs/aws langchain/libs/partners/aws
- name: "🐍 Set up Python ${{ matrix.python-version }} + Poetry"
if: contains(env.POETRY_LIBS, matrix.working-directory)
uses: "./langchain/.github/actions/poetry_setup"
with:
python-version: ${{ matrix.python-version }}
poetry-version: ${{ env.POETRY_VERSION }}
working-directory: langchain/${{ matrix.working-directory }}
cache-key: scheduled
- name: "🐍 Set up Python ${{ matrix.python-version }} + UV"
if: "!contains(env.POETRY_LIBS, matrix.working-directory)"
uses: "./langchain/.github/actions/uv_setup"
with:
python-version: ${{ matrix.python-version }}
@@ -123,15 +110,7 @@ jobs:
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ secrets.AWS_REGION }}
- name: "📦 Install Dependencies (Poetry)"
if: contains(env.POETRY_LIBS, matrix.working-directory)
run: |
echo "Running scheduled tests, installing dependencies with poetry..."
cd langchain/${{ matrix.working-directory }}
poetry install --with=test_integration,test
- name: "📦 Install Dependencies (UV)"
if: "!contains(env.POETRY_LIBS, matrix.working-directory)"
- name: "📦 Install Dependencies"
run: |
echo "Running scheduled tests, installing dependencies with uv..."
cd langchain/${{ matrix.working-directory }}

View File

@@ -152,11 +152,11 @@ def send_email(to: str, msg: str, *, priority: str = "normal") -> bool:
priority: Email priority level (`'low'`, `'normal'`, `'high'`).
Returns:
True if email was sent successfully, False otherwise.
`True` if email was sent successfully, `False` otherwise.
Raises:
InvalidEmailError: If the email address format is invalid.
SMTPConnectionError: If unable to connect to email server.
`InvalidEmailError`: If the email address format is invalid.
`SMTPConnectionError`: If unable to connect to email server.
"""
```

View File

@@ -152,11 +152,11 @@ def send_email(to: str, msg: str, *, priority: str = "normal") -> bool:
priority: Email priority level (`'low'`, `'normal'`, `'high'`).
Returns:
True if email was sent successfully, False otherwise.
`True` if email was sent successfully, `False` otherwise.
Raises:
InvalidEmailError: If the email address format is invalid.
SMTPConnectionError: If unable to connect to email server.
`InvalidEmailError`: If the email address format is invalid.
`SMTPConnectionError`: If unable to connect to email server.
"""
```

View File

@@ -19,8 +19,8 @@ And you should configure credentials by setting the following environment variab
```python
from __module_name__ import Chat__ModuleName__
llm = Chat__ModuleName__()
llm.invoke("Sing a ballad of LangChain.")
model = Chat__ModuleName__()
model.invoke("Sing a ballad of LangChain.")
```
## Embeddings
@@ -41,6 +41,6 @@ embeddings.embed_query("What is the meaning of life?")
```python
from __module_name__ import __ModuleName__LLM
llm = __ModuleName__LLM()
llm.invoke("The meaning of life is")
model = __ModuleName__LLM()
model.invoke("The meaning of life is")
```

View File

@@ -72,7 +72,9 @@
"cell_type": "markdown",
"id": "72ee0c4b-9764-423a-9dbf-95129e185210",
"metadata": {},
"source": "To enable automated tracing of your model calls, set your [LangSmith](https://docs.smith.langchain.com/) API key:"
"source": [
"To enable automated tracing of your model calls, set your [LangSmith](https://docs.smith.langchain.com/) API key:"
]
},
{
"cell_type": "code",
@@ -126,7 +128,7 @@
"source": [
"from __module_name__ import Chat__ModuleName__\n",
"\n",
"llm = Chat__ModuleName__(\n",
"model = Chat__ModuleName__(\n",
" model=\"model-name\",\n",
" temperature=0,\n",
" max_tokens=None,\n",
@@ -162,7 +164,7 @@
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg = model.invoke(messages)\n",
"ai_msg"
]
},
@@ -207,7 +209,7 @@
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain = prompt | model\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",

View File

@@ -65,7 +65,9 @@
"cell_type": "markdown",
"id": "4b6e1ca6",
"metadata": {},
"source": "To enable automated tracing of your model calls, set your [LangSmith](https://docs.smith.langchain.com/) API key:"
"source": [
"To enable automated tracing of your model calls, set your [LangSmith](https://docs.smith.langchain.com/) API key:"
]
},
{
"cell_type": "code",
@@ -119,7 +121,7 @@
"source": [
"from __module_name__ import __ModuleName__LLM\n",
"\n",
"llm = __ModuleName__LLM(\n",
"model = __ModuleName__LLM(\n",
" model=\"model-name\",\n",
" temperature=0,\n",
" max_tokens=None,\n",
@@ -141,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "035dea0f",
"metadata": {
"tags": []
@@ -150,7 +152,7 @@
"source": [
"input_text = \"__ModuleName__ is an AI company that \"\n",
"\n",
"completion = llm.invoke(input_text)\n",
"completion = model.invoke(input_text)\n",
"completion"
]
},
@@ -177,7 +179,7 @@
"\n",
"prompt = PromptTemplate(\"How to say {input} in {output_language}:\\n\")\n",
"\n",
"chain = prompt | llm\n",
"chain = prompt | model\n",
"chain.invoke(\n",
" {\n",
" \"output_language\": \"German\",\n",

View File

@@ -155,7 +155,7 @@
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)"
"model = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)"
]
},
{
@@ -185,7 +185,7 @@
"chain = (\n",
" {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
" | prompt\n",
" | llm\n",
" | model\n",
" | StrOutputParser()\n",
")"
]

View File

@@ -192,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "af3123ad-7a02-40e5-b58e-7d56e23e5830",
"metadata": {},
"outputs": [],
@@ -203,7 +203,7 @@
"# !pip install -qU langchain langchain-openai\n",
"from langchain.chat_models import init_chat_model\n",
"\n",
"llm = init_chat_model(model=\"gpt-4o\", model_provider=\"openai\")"
"model = init_chat_model(model=\"gpt-4o\", model_provider=\"openai\")"
]
},
{
@@ -216,7 +216,7 @@
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"tools = [tool]\n",
"agent = create_react_agent(llm, tools)"
"agent = create_react_agent(model, tools)"
]
},
{

View File

@@ -60,7 +60,7 @@ class Chat__ModuleName__(BaseChatModel):
```python
from __module_name__ import Chat__ModuleName__
llm = Chat__ModuleName__(
model = Chat__ModuleName__(
model="...",
temperature=0,
max_tokens=None,
@@ -77,7 +77,7 @@ class Chat__ModuleName__(BaseChatModel):
("system", "You are a helpful translator. Translate the user sentence to French."),
("human", "I love programming."),
]
llm.invoke(messages)
model.invoke(messages)
```
```python
@@ -87,7 +87,7 @@ class Chat__ModuleName__(BaseChatModel):
# TODO: Delete if token-level streaming isn't supported.
Stream:
```python
for chunk in llm.stream(messages):
for chunk in model.stream(messages):
print(chunk.text, end="")
```
@@ -96,7 +96,7 @@ class Chat__ModuleName__(BaseChatModel):
```
```python
stream = llm.stream(messages)
stream = model.stream(messages)
full = next(stream)
for chunk in stream:
full += chunk
@@ -110,13 +110,13 @@ class Chat__ModuleName__(BaseChatModel):
# TODO: Delete if native async isn't supported.
Async:
```python
await llm.ainvoke(messages)
await model.ainvoke(messages)
# stream:
# async for chunk in (await llm.astream(messages))
# async for chunk in (await model.astream(messages))
# batch:
# await llm.abatch([messages])
# await model.abatch([messages])
```
```python
@@ -137,8 +137,8 @@ class Chat__ModuleName__(BaseChatModel):
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
model_with_tools = model.bind_tools([GetWeather, GetPopulation])
ai_msg = model_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
ai_msg.tool_calls
```
@@ -162,8 +162,8 @@ class Chat__ModuleName__(BaseChatModel):
punchline: str = Field(description="The punchline to the joke")
rating: int | None = Field(description="How funny the joke is, from 1 to 10")
structured_llm = llm.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats")
structured_model = model.with_structured_output(Joke)
structured_model.invoke("Tell me a joke about cats")
```
```python
@@ -176,8 +176,8 @@ class Chat__ModuleName__(BaseChatModel):
JSON mode:
```python
# TODO: Replace with appropriate bind arg.
json_llm = llm.bind(response_format={"type": "json_object"})
ai_msg = json_llm.invoke("Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]")
json_model = model.bind(response_format={"type": "json_object"})
ai_msg = json_model.invoke("Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]")
ai_msg.content
```
@@ -204,7 +204,7 @@ class Chat__ModuleName__(BaseChatModel):
},
],
)
ai_msg = llm.invoke([message])
ai_msg = model.invoke([message])
ai_msg.content
```
@@ -235,7 +235,7 @@ class Chat__ModuleName__(BaseChatModel):
# TODO: Delete if token usage metadata isn't supported.
Token usage:
```python
ai_msg = llm.invoke(messages)
ai_msg = model.invoke(messages)
ai_msg.usage_metadata
```
@@ -247,8 +247,8 @@ class Chat__ModuleName__(BaseChatModel):
Logprobs:
```python
# TODO: Replace with appropriate bind arg.
logprobs_llm = llm.bind(logprobs=True)
ai_msg = logprobs_llm.invoke(messages)
logprobs_model = model.bind(logprobs=True)
ai_msg = logprobs_model.invoke(messages)
ai_msg.response_metadata["logprobs"]
```
@@ -257,7 +257,7 @@ class Chat__ModuleName__(BaseChatModel):
```
Response metadata
```python
ai_msg = llm.invoke(messages)
ai_msg = model.invoke(messages)
ai_msg.response_metadata
```

View File

@@ -65,7 +65,7 @@ class __ModuleName__Retriever(BaseRetriever):
Question: {question}\"\"\"
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
model = ChatOpenAI(model="gpt-3.5-turbo-0125")
def format_docs(docs):
return "\\n\\n".join(doc.page_content for doc in docs)
@@ -73,7 +73,7 @@ class __ModuleName__Retriever(BaseRetriever):
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| model
| StrOutputParser()
)

View File

@@ -65,7 +65,7 @@ def is_subclass(class_obj: type, classes_: list[type]) -> bool:
classes_: A list of classes to check against.
Returns:
True if `class_obj` is a subclass of any class in `classes_`, False otherwise.
True if `class_obj` is a subclass of any class in `classes_`, `False` otherwise.
"""
return any(
issubclass(class_obj, kls)

View File

@@ -35,7 +35,7 @@ def is_openai_data_block(
different type, this function will return False.
Returns:
True if the block is a valid OpenAI data block and matches the filter_
`True` if the block is a valid OpenAI data block and matches the filter_
(if provided).
"""

View File

@@ -123,7 +123,6 @@ class BaseLanguageModel(
* If instance of `BaseCache`, will use the provided cache.
Caching is not currently supported for streaming methods of models.
"""
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
"""Whether to print out response text."""
@@ -146,7 +145,7 @@ class BaseLanguageModel(
def set_verbose(cls, verbose: bool | None) -> bool: # noqa: FBT001
"""If verbose is `None`, set it.
This allows users to pass in None as verbose to access the global setting.
This allows users to pass in `None` as verbose to access the global setting.
Args:
verbose: The verbosity setting to use.
@@ -186,12 +185,12 @@ class BaseLanguageModel(
1. Take advantage of batched calls,
2. Need more output from the model than just the top generated value,
3. Are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of PromptValues. A PromptValue is an object that can be
converted to match the format of any language model (string for pure
text generation models and BaseMessages for chat models).
prompts: List of `PromptValue` objects. A `PromptValue` is an object that
can be converted to match the format of any language model (string for
pure text generation models and `BaseMessage` objects for chat models).
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
@@ -200,8 +199,8 @@ class BaseLanguageModel(
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
An `LLMResult`, which contains a list of candidate `Generation` objects for
each input prompt and additional model provider-specific output.
"""
@@ -223,12 +222,12 @@ class BaseLanguageModel(
1. Take advantage of batched calls,
2. Need more output from the model than just the top generated value,
3. Are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of PromptValues. A PromptValue is an object that can be
converted to match the format of any language model (string for pure
text generation models and BaseMessages for chat models).
prompts: List of `PromptValue` objects. A `PromptValue` is an object that
can be converted to match the format of any language model (string for
pure text generation models and `BaseMessage` objects for chat models).
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
@@ -237,8 +236,8 @@ class BaseLanguageModel(
to the model provider API call.
Returns:
An `LLMResult`, which contains a list of candidate Generations for each
input prompt and additional model provider-specific output.
An `LLMResult`, which contains a list of candidate `Generation` objects for
each input prompt and additional model provider-specific output.
"""

View File

@@ -240,78 +240,54 @@ def _format_ls_structured_output(ls_structured_output_format: dict | None) -> di
class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
"""Base class for chat models.
r"""Base class for chat models.
Key imperative methods:
Methods that actually call the underlying model.
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| Method | Input | Output | Description |
+===========================+================================================================+=====================================================================+==================================================================================================+
| `invoke` | str | list[dict | tuple | BaseMessage] | PromptValue | BaseMessage | A single chat model call. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| `ainvoke` | ''' | BaseMessage | Defaults to running invoke in an async executor. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| `stream` | ''' | Iterator[BaseMessageChunk] | Defaults to yielding output of invoke. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| `astream` | ''' | AsyncIterator[BaseMessageChunk] | Defaults to yielding output of ainvoke. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| `astream_events` | ''' | AsyncIterator[StreamEvent] | Event types: 'on_chat_model_start', 'on_chat_model_stream', 'on_chat_model_end'. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| `batch` | list['''] | list[BaseMessage] | Defaults to running invoke in concurrent threads. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| `abatch` | list['''] | list[BaseMessage] | Defaults to running ainvoke in concurrent threads. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| `batch_as_completed` | list['''] | Iterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running invoke in concurrent threads. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
| `abatch_as_completed` | list['''] | AsyncIterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running ainvoke in concurrent threads. |
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
This table provides a brief overview of the main imperative methods. Please see the base `Runnable` reference for full documentation.
This table provides a brief overview of the main imperative methods. Please see the base Runnable reference for full documentation.
| Method | Input | Output | Description |
| ---------------------- | ------------------------------------------------------------ | ---------------------------------------------------------- | -------------------------------------------------------------------------------- |
| `invoke` | `str` \| `list[dict | tuple | BaseMessage]` \| `PromptValue` | `BaseMessage` | A single chat model call. |
| `ainvoke` | `'''` | `BaseMessage` | Defaults to running `invoke` in an async executor. |
| `stream` | `'''` | `Iterator[BaseMessageChunk]` | Defaults to yielding output of `invoke`. |
| `astream` | `'''` | `AsyncIterator[BaseMessageChunk]` | Defaults to yielding output of `ainvoke`. |
| `astream_events` | `'''` | `AsyncIterator[StreamEvent]` | Event types: `on_chat_model_start`, `on_chat_model_stream`, `on_chat_model_end`. |
| `batch` | `list[''']` | `list[BaseMessage]` | Defaults to running `invoke` in concurrent threads. |
| `abatch` | `list[''']` | `list[BaseMessage]` | Defaults to running `ainvoke` in concurrent threads. |
| `batch_as_completed` | `list[''']` | `Iterator[tuple[int, Union[BaseMessage, Exception]]]` | Defaults to running `invoke` in concurrent threads. |
| `abatch_as_completed` | `list[''']` | `AsyncIterator[tuple[int, Union[BaseMessage, Exception]]]` | Defaults to running `ainvoke` in concurrent threads. |
Key declarative methods:
Methods for creating another Runnable using the ChatModel.
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
| Method | Description |
+==================================+===========================================================================================================+
| `bind_tools` | Create ChatModel that can call tools. |
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
| `with_structured_output` | Create wrapper that structures model output using schema. |
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
| `with_retry` | Create wrapper that retries model calls on failure. |
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
| `with_fallbacks` | Create wrapper that falls back to other models on failure. |
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
| `configurable_fields` | Specify init args of the model that can be configured at runtime via the RunnableConfig. |
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
| `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the RunnableConfig. |
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
Methods for creating another `Runnable` using the chat model.
This table provides a brief overview of the main declarative methods. Please see the reference for each method for full documentation.
| Method | Description |
| ---------------------------- | -------------------------------------------------------------------------------------------- |
| `bind_tools` | Create chat model that can call tools. |
| `with_structured_output` | Create wrapper that structures model output using schema. |
| `with_retry` | Create wrapper that retries model calls on failure. |
| `with_fallbacks` | Create wrapper that falls back to other models on failure. |
| `configurable_fields` | Specify init args of the model that can be configured at runtime via the `RunnableConfig`. |
| `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the `RunnableConfig`. |
Creating custom chat model:
Custom chat model implementations should inherit from this class.
Please reference the table below for information about which
methods and properties are required or optional for implementations.
+----------------------------------+--------------------------------------------------------------------+-------------------+
| Method/Property | Description | Required/Optional |
+==================================+====================================================================+===================+
| -------------------------------- | ------------------------------------------------------------------ | ----------------- |
| `_generate` | Use to generate a chat result from a prompt | Required |
+----------------------------------+--------------------------------------------------------------------+-------------------+
| `_llm_type` (property) | Used to uniquely identify the type of the model. Used for logging. | Required |
+----------------------------------+--------------------------------------------------------------------+-------------------+
| `_identifying_params` (property) | Represent model parameterization for tracing purposes. | Optional |
+----------------------------------+--------------------------------------------------------------------+-------------------+
| `_stream` | Use to implement streaming | Optional |
+----------------------------------+--------------------------------------------------------------------+-------------------+
| `_agenerate` | Use to implement a native async method | Optional |
+----------------------------------+--------------------------------------------------------------------+-------------------+
| `_astream` | Use to implement async version of `_stream` | Optional |
+----------------------------------+--------------------------------------------------------------------+-------------------+
Follow the guide for more information on how to implement a custom Chat Model:
Follow the guide for more information on how to implement a custom chat model:
[Guide](https://python.langchain.com/docs/how_to/custom_chat_model/).
""" # noqa: E501
@@ -327,9 +303,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
- If `True`, will always bypass streaming case.
- If `'tool_calling'`, will bypass streaming case only when the model is called
with a `tools` keyword argument. In other words, LangChain will automatically
switch to non-streaming behavior (`invoke`) only when the tools argument is
provided. This offers the best of both worlds.
with a `tools` keyword argument. In other words, LangChain will automatically
switch to non-streaming behavior (`invoke`) only when the tools argument is
provided. This offers the best of both worlds.
- If `False` (Default), will always use streaming case if available.
The main reason for this flag is that code might be written using `stream` and
@@ -349,7 +325,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
Supported values:
- `'v0'`: provider-specific format in content (can lazily-parse with
`.content_blocks`)
`.content_blocks`)
- `'v1'`: standardized format in content (consistent with `.content_blocks`)
Partner packages (e.g., `langchain-openai`) can also use this field to roll out
@@ -1579,10 +1555,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
justification: str
llm = ChatModel(model="model-name", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)
model = ChatModel(model="model-name", temperature=0)
structured_model = model.with_structured_output(AnswerWithJustification)
structured_llm.invoke(
structured_model.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
@@ -1604,12 +1580,12 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
justification: str
llm = ChatModel(model="model-name", temperature=0)
structured_llm = llm.with_structured_output(
model = ChatModel(model="model-name", temperature=0)
structured_model = model.with_structured_output(
AnswerWithJustification, include_raw=True
)
structured_llm.invoke(
structured_model.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
@@ -1633,10 +1609,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = ChatModel(model="model-name", temperature=0)
structured_llm = llm.with_structured_output(dict_schema)
model = ChatModel(model="model-name", temperature=0)
structured_model = model.with_structured_output(dict_schema)
structured_llm.invoke(
structured_model.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {

View File

@@ -25,9 +25,9 @@ class BaseSerialized(TypedDict):
id: list[str]
"""The unique identifier of the object."""
name: NotRequired[str]
"""The name of the object. Optional."""
"""The name of the object."""
graph: NotRequired[dict[str, Any]]
"""The graph of the object. Optional."""
"""The graph of the object."""
class SerializedConstructor(BaseSerialized):
@@ -52,7 +52,7 @@ class SerializedNotImplemented(BaseSerialized):
type: Literal["not_implemented"]
"""The type of the object. Must be `'not_implemented'`."""
repr: str | None
"""The representation of the object. Optional."""
"""The representation of the object."""
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
@@ -61,7 +61,7 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
Args:
value: The value.
key: The key.
model: The pydantic model.
model: The Pydantic model.
Returns:
Whether the value is different from the default.
@@ -93,18 +93,18 @@ class Serializable(BaseModel, ABC):
It relies on the following methods and properties:
- `is_lc_serializable`: Is this class serializable?
By design, even if a class inherits from Serializable, it is not serializable by
default. This is to prevent accidental serialization of objects that should not
be serialized.
By design, even if a class inherits from `Serializable`, it is not serializable
by default. This is to prevent accidental serialization of objects that should
not be serialized.
- `get_lc_namespace`: Get the namespace of the langchain object.
During deserialization, this namespace is used to identify
the correct class to instantiate.
Please see the `Reviver` class in `langchain_core.load.load` for more details.
During deserialization an additional mapping is handle
classes that have moved or been renamed across package versions.
During deserialization an additional mapping is handle classes that have moved
or been renamed across package versions.
- `lc_secrets`: A map of constructor argument names to secret ids.
- `lc_attributes`: List of additional attribute names that should be included
as part of the serialized representation.
as part of the serialized representation.
"""
# Remove default BaseModel init docstring.
@@ -116,12 +116,12 @@ class Serializable(BaseModel, ABC):
def is_lc_serializable(cls) -> bool:
"""Is this class serializable?
By design, even if a class inherits from Serializable, it is not serializable by
default. This is to prevent accidental serialization of objects that should not
be serialized.
By design, even if a class inherits from `Serializable`, it is not serializable
by default. This is to prevent accidental serialization of objects that should
not be serialized.
Returns:
Whether the class is serializable. Default is False.
Whether the class is serializable. Default is `False`.
"""
return False
@@ -133,7 +133,7 @@ class Serializable(BaseModel, ABC):
namespace is ["langchain", "llms", "openai"]
Returns:
The namespace as a list of strings.
The namespace.
"""
return cls.__module__.split(".")
@@ -141,8 +141,7 @@ class Serializable(BaseModel, ABC):
def lc_secrets(self) -> dict[str, str]:
"""A map of constructor argument names to secret ids.
For example,
{"openai_api_key": "OPENAI_API_KEY"}
For example, `{"openai_api_key": "OPENAI_API_KEY"}`
"""
return {}
@@ -151,6 +150,7 @@ class Serializable(BaseModel, ABC):
"""List of attribute names that should be included in the serialized kwargs.
These attributes must be accepted by the constructor.
Default is an empty dictionary.
"""
return {}
@@ -194,7 +194,7 @@ class Serializable(BaseModel, ABC):
ValueError: If the class has deprecated attributes.
Returns:
A json serializable object or a SerializedNotImplemented object.
A json serializable object or a `SerializedNotImplemented` object.
"""
if not self.is_lc_serializable():
return self.to_json_not_implemented()
@@ -269,7 +269,7 @@ class Serializable(BaseModel, ABC):
"""Serialize a "not implemented" object.
Returns:
SerializedNotImplemented.
`SerializedNotImplemented`.
"""
return to_json_not_implemented(self)
@@ -284,8 +284,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
Returns:
Whether the field is useful. If the field is required, it is useful.
If the field is not required, it is useful if the value is not None.
If the field is not required and the value is None, it is useful if the
If the field is not required, it is useful if the value is not `None`.
If the field is not required and the value is `None`, it is useful if the
default value is different from the value.
"""
field = type(inst).model_fields.get(key)
@@ -344,10 +344,10 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
"""Serialize a "not implemented" object.
Args:
obj: object to serialize.
obj: Object to serialize.
Returns:
SerializedNotImplemented
`SerializedNotImplemented`
"""
id_: list[str] = []
try:

View File

@@ -877,7 +877,7 @@ def is_data_content_block(block: dict) -> bool:
block: The content block to check.
Returns:
True if the content block is a data content block, False otherwise.
`True` if the content block is a data content block, `False` otherwise.
"""
if block.get("type") not in _get_data_content_block_types():

View File

@@ -31,13 +31,13 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
partial: Whether to parse partial JSON objects.
Returns:
The parsed JSON object.
Raises:
OutputParserException: If the output is not valid JSON.
`OutputParserException`: If the output is not valid JSON.
"""
generation = result[0]
if not isinstance(generation, ChatGeneration):
@@ -56,7 +56,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse an output as the Json object."""
"""Parse an output as the JSON object."""
strict: bool = False
"""Whether to allow non-JSON-compliant strings.
@@ -82,13 +82,13 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
partial: Whether to parse partial JSON objects.
Returns:
The parsed JSON object.
Raises:
OutputParserException: If the output is not valid JSON.
OutputParserExcept`ion: If the output is not valid JSON.
"""
if len(result) != 1:
msg = f"Expected exactly one result, but got {len(result)}"
@@ -155,7 +155,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
"""Parse an output as the element of the Json object."""
"""Parse an output as the element of the JSON object."""
key_name: str
"""The name of the key to return."""
@@ -165,7 +165,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
partial: Whether to parse partial JSON objects.
Returns:
The parsed JSON object.
@@ -177,16 +177,15 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as a pydantic object.
"""Parse an output as a Pydantic object.
This parser is used to parse the output of a ChatModel that uses
OpenAI function format to invoke functions.
This parser is used to parse the output of a chat model that uses OpenAI function
format to invoke functions.
The parser extracts the function call invocation and matches
them to the pydantic schema provided.
The parser extracts the function call invocation and matches them to the Pydantic
schema provided.
An exception will be raised if the function call does not match
the provided schema.
An exception will be raised if the function call does not match the provided schema.
Example:
```python
@@ -221,7 +220,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""
pydantic_schema: type[BaseModel] | dict[str, type[BaseModel]]
"""The pydantic schema to parse the output with.
"""The Pydantic schema to parse the output with.
If multiple schemas are provided, then the function name will be used to
determine which schema to use.
@@ -230,7 +229,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
@model_validator(mode="before")
@classmethod
def validate_schema(cls, values: dict) -> Any:
"""Validate the pydantic schema.
"""Validate the Pydantic schema.
Args:
values: The values to validate.
@@ -239,7 +238,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
The validated values.
Raises:
ValueError: If the schema is not a pydantic schema.
`ValueError`: If the schema is not a Pydantic schema.
"""
schema = values["pydantic_schema"]
if "args_only" not in values:
@@ -262,10 +261,10 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
partial: Whether to parse partial JSON objects.
Raises:
ValueError: If the pydantic schema is not valid.
`ValueError`: If the Pydantic schema is not valid.
Returns:
The parsed JSON object.
@@ -288,13 +287,13 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
elif issubclass(pydantic_schema, BaseModelV1):
pydantic_args = pydantic_schema.parse_raw(args)
else:
msg = f"Unsupported pydantic schema: {pydantic_schema}"
msg = f"Unsupported Pydantic schema: {pydantic_schema}"
raise ValueError(msg)
return pydantic_args
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""Parse an output as an attribute of a pydantic object."""
"""Parse an output as an attribute of a Pydantic object."""
attr_name: str
"""The name of the attribute to return."""
@@ -305,7 +304,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
partial: Whether to parse partial JSON objects.
Returns:
The parsed JSON object.

View File

@@ -17,10 +17,10 @@ from langchain_core.utils.pydantic import (
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""Parse an output using a pydantic model."""
"""Parse an output using a Pydantic model."""
pydantic_object: Annotated[type[TBaseModel], SkipValidation()]
"""The pydantic model to parse."""
"""The Pydantic model to parse."""
def _parse_obj(self, obj: dict) -> TBaseModel:
try:
@@ -45,21 +45,20 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
def parse_result(
self, result: list[Generation], *, partial: bool = False
) -> TBaseModel | None:
"""Parse the result of an LLM call to a pydantic object.
"""Parse the result of an LLM call to a Pydantic object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects.
If `True`, the output will be a JSON object containing
all the keys that have been returned so far.
Defaults to `False`.
Raises:
OutputParserException: If the result is not valid JSON
or does not conform to the pydantic model.
`OutputParserException`: If the result is not valid JSON
or does not conform to the Pydantic model.
Returns:
The parsed pydantic object.
The parsed Pydantic object.
"""
try:
json_object = super().parse_result(result)
@@ -70,13 +69,13 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
raise
def parse(self, text: str) -> TBaseModel:
"""Parse the output of an LLM call to a pydantic object.
"""Parse the output of an LLM call to a Pydantic object.
Args:
text: The output of the LLM call.
Returns:
The parsed pydantic object.
The parsed Pydantic object.
"""
return super().parse(text)
@@ -107,7 +106,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
@property
@override
def OutputType(self) -> type[TBaseModel]:
"""Return the pydantic model."""
"""Return the Pydantic model."""
return self.pydantic_object

View File

@@ -97,7 +97,7 @@ class LLMResult(BaseModel):
other: Another `LLMResult` object to compare against.
Returns:
True if the generations and `llm_output` are equal, False otherwise.
`True` if the generations and `llm_output` are equal, `False` otherwise.
"""
if not isinstance(other, LLMResult):
return NotImplemented

View File

@@ -44,7 +44,7 @@ class BaseRateLimiter(abc.ABC):
the attempt. Defaults to `True`.
Returns:
True if the tokens were successfully acquired, False otherwise.
`True` if the tokens were successfully acquired, `False` otherwise.
"""
@abc.abstractmethod
@@ -63,7 +63,7 @@ class BaseRateLimiter(abc.ABC):
the attempt. Defaults to `True`.
Returns:
True if the tokens were successfully acquired, False otherwise.
`True` if the tokens were successfully acquired, `False` otherwise.
"""
@@ -210,7 +210,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
the attempt. Defaults to `True`.
Returns:
True if the tokens were successfully acquired, False otherwise.
`True` if the tokens were successfully acquired, `False` otherwise.
"""
if not blocking:
return self._consume()
@@ -234,7 +234,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
the attempt. Defaults to `True`.
Returns:
True if the tokens were successfully acquired, False otherwise.
`True` if the tokens were successfully acquired, `False` otherwise.
"""
if not blocking:
return self._consume()

View File

@@ -7,7 +7,6 @@ the backbone of a retriever, but there are other types of retrievers as well.
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any
@@ -15,8 +14,6 @@ from typing import TYPE_CHECKING, Any
from pydantic import ConfigDict
from typing_extensions import Self, TypedDict, override
from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
from langchain_core.documents import Document
from langchain_core.runnables import (
@@ -138,35 +135,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
@override
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
# Version upgrade for old retrievers that implemented the public
# methods directly.
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
warnings.warn(
"Retrievers must implement abstract `_get_relevant_documents` method"
" instead of `get_relevant_documents`",
DeprecationWarning,
stacklevel=4,
)
swap = cls.get_relevant_documents
cls.get_relevant_documents = ( # type: ignore[method-assign]
BaseRetriever.get_relevant_documents
)
cls._get_relevant_documents = swap # type: ignore[method-assign]
if (
hasattr(cls, "aget_relevant_documents")
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
):
warnings.warn(
"Retrievers must implement abstract `_aget_relevant_documents` method"
" instead of `aget_relevant_documents`",
DeprecationWarning,
stacklevel=4,
)
aswap = cls.aget_relevant_documents
cls.aget_relevant_documents = ( # type: ignore[method-assign]
BaseRetriever.aget_relevant_documents
)
cls._aget_relevant_documents = aswap # type: ignore[method-assign]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
if (
@@ -348,91 +316,3 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
query,
run_manager=run_manager.get_sync(),
)
@deprecated(since="0.1.46", alternative="invoke", removal="1.0")
def get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
run_name: str | None = None,
**kwargs: Any,
) -> list[Document]:
"""Retrieve documents relevant to a query.
Users should favor using `.invoke` or `.batch` rather than
`get_relevant_documents directly`.
Args:
query: string to find relevant documents for.
callbacks: Callback manager or list of callbacks.
tags: Optional list of tags associated with the retriever.
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever.
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
run_name: Optional name for the run.
**kwargs: Additional arguments to pass to the retriever.
Returns:
List of relevant documents.
"""
config: RunnableConfig = {}
if callbacks:
config["callbacks"] = callbacks
if tags:
config["tags"] = tags
if metadata:
config["metadata"] = metadata
if run_name:
config["run_name"] = run_name
return self.invoke(query, config, **kwargs)
@deprecated(since="0.1.46", alternative="ainvoke", removal="1.0")
async def aget_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
run_name: str | None = None,
**kwargs: Any,
) -> list[Document]:
"""Asynchronously get documents relevant to a query.
Users should favor using `.ainvoke` or `.abatch` rather than
`aget_relevant_documents directly`.
Args:
query: string to find relevant documents for.
callbacks: Callback manager or list of callbacks.
tags: Optional list of tags associated with the retriever.
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever.
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
run_name: Optional name for the run.
**kwargs: Additional arguments to pass to the retriever.
Returns:
List of relevant documents.
"""
config: RunnableConfig = {}
if callbacks:
config["callbacks"] = callbacks
if tags:
config["tags"] = tags
if metadata:
config["metadata"] = metadata
if run_name:
config["run_name"] = run_name
return await self.ainvoke(query, config, **kwargs)

View File

@@ -304,7 +304,7 @@ class Runnable(ABC, Generic[Input, Output]):
TypeError: If the input type cannot be inferred.
"""
# First loop through all parent classes and if any of them is
# a pydantic model, we will pick up the generic parameterization
# a Pydantic model, we will pick up the generic parameterization
# from that model via the __pydantic_generic_metadata__ attribute.
for base in self.__class__.mro():
if hasattr(base, "__pydantic_generic_metadata__"):
@@ -312,7 +312,7 @@ class Runnable(ABC, Generic[Input, Output]):
if "args" in metadata and len(metadata["args"]) == 2:
return metadata["args"][0]
# If we didn't find a pydantic model in the parent classes,
# If we didn't find a Pydantic model in the parent classes,
# then loop through __orig_bases__. This corresponds to
# Runnables that are not pydantic models.
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
@@ -390,7 +390,7 @@ class Runnable(ABC, Generic[Input, Output]):
self.get_name("Input"),
root=root_type,
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# able to construct the Pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
@@ -433,7 +433,7 @@ class Runnable(ABC, Generic[Input, Output]):
def output_schema(self) -> type[BaseModel]:
"""Output schema.
The type of output this `Runnable` produces specified as a pydantic model.
The type of output this `Runnable` produces specified as a Pydantic model.
"""
return self.get_output_schema()
@@ -468,7 +468,7 @@ class Runnable(ABC, Generic[Input, Output]):
self.get_name("Output"),
root=root_type,
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# able to construct the Pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
@@ -776,11 +776,11 @@ class Runnable(ABC, Generic[Input, Output]):
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeStreamingListLLM(responses=["foo-lish"])
model = FakeStreamingListLLM(responses=["foo-lish"])
chain: Runnable = prompt | llm | {"str": StrOutputParser()}
chain: Runnable = prompt | model | {"str": StrOutputParser()}
chain_with_assign = chain.assign(hello=itemgetter("str") | llm)
chain_with_assign = chain.assign(hello=itemgetter("str") | model)
print(chain_with_assign.input_schema.model_json_schema())
# {'title': 'PromptInput', 'type': 'object', 'properties':
@@ -1273,22 +1273,20 @@ class Runnable(ABC, Generic[Input, Output]):
A `StreamEvent` is a dictionary with the following schema:
- `event`: **str** - Event names are of the format:
- `event`: Event names are of the format:
`on_[runnable_type]_(start|stream|end)`.
- `name`: **str** - The name of the `Runnable` that generated the event.
- `run_id`: **str** - randomly generated ID associated with the given
execution of the `Runnable` that emitted the event. A child `Runnable` that gets
invoked as part of the execution of a parent `Runnable` is assigned its own
unique ID.
- `parent_ids`: **list[str]** - The IDs of the parent runnables that generated
the event. The root `Runnable` will have an empty list. The order of the parent
IDs is from the root to the immediate parent. Only available for v2 version of
the API. The v1 version of the API will return an empty list.
- `tags`: **list[str] | None** - The tags of the `Runnable` that generated
the event.
- `metadata`: **dict[str, Any] | None** - The metadata of the `Runnable` that
generated the event.
- `data`: **dict[str, Any]**
- `name`: The name of the `Runnable` that generated the event.
- `run_id`: Randomly generated ID associated with the given execution of the
`Runnable` that emitted the event. A child `Runnable` that gets invoked as
part of the execution of a parent `Runnable` is assigned its own unique ID.
- `parent_ids`: The IDs of the parent runnables that generated the event. The
root `Runnable` will have an empty list. The order of the parent IDs is from
the root to the immediate parent. Only available for v2 version of the API.
The v1 version of the API will return an empty list.
- `tags`: The tags of the `Runnable` that generated the event.
- `metadata`: The metadata of the `Runnable` that generated the event.
- `data`: The data associated with the event. The contents of this field
depend on the type of event. See the table below for more details.
Below is a table that illustrates some events that might be emitted by various
chains. Metadata fields have been omitted from the table for brevity.
@@ -1297,39 +1295,23 @@ class Runnable(ABC, Generic[Input, Output]):
!!! note
This reference table is for the v2 version of the schema.
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| event | name | chunk | input | output |
+==========================+==================+=====================================+===================================================+=====================================================+
| `on_chat_model_start` | [model name] | | `{"messages": [[SystemMessage, HumanMessage]]}` | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_chat_model_stream` | [model name] | `AIMessageChunk(content="hello")` | | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_chat_model_end` | [model name] | | `{"messages": [[SystemMessage, HumanMessage]]}` | `AIMessageChunk(content="hello world")` |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_llm_start` | [model name] | | `{'input': 'hello'}` | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_llm_stream` | [model name] | `'Hello' ` | | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_llm_end` | [model name] | | `'Hello human!'` | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_chain_start` | format_docs | | | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_chain_stream` | format_docs | `'hello world!, goodbye world!'` | | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_chain_end` | format_docs | | `[Document(...)]` | `'hello world!, goodbye world!'` |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_tool_start` | some_tool | | `{"x": 1, "y": "2"}` | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_tool_end` | some_tool | | | `{"x": 1, "y": "2"}` |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_retriever_start` | [retriever name] | | `{"query": "hello"}` | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_retriever_end` | [retriever name] | | `{"query": "hello"}` | `[Document(...), ..]` |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_prompt_start` | [template_name] | | `{"question": "hello"}` | |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| `on_prompt_end` | [template_name] | | `{"question": "hello"}` | `ChatPromptValue(messages: [SystemMessage, ...])` |
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
| event | name | chunk | input | output |
| ---------------------- | -------------------- | ----------------------------------- | ------------------------------------------------- | --------------------------------------------------- |
| `on_chat_model_start` | `'[model name]'` | | `{"messages": [[SystemMessage, HumanMessage]]}` | |
| `on_chat_model_stream` | `'[model name]'` | `AIMessageChunk(content="hello")` | | |
| `on_chat_model_end` | `'[model name]'` | | `{"messages": [[SystemMessage, HumanMessage]]}` | `AIMessageChunk(content="hello world")` |
| `on_llm_start` | `'[model name]'` | | `{'input': 'hello'}` | |
| `on_llm_stream` | `'[model name]'` | `'Hello' ` | | |
| `on_llm_end` | `'[model name]'` | | `'Hello human!'` | |
| `on_chain_start` | `'format_docs'` | | | |
| `on_chain_stream` | `'format_docs'` | `'hello world!, goodbye world!'` | | |
| `on_chain_end` | `'format_docs'` | | `[Document(...)]` | `'hello world!, goodbye world!'` |
| `on_tool_start` | `'some_tool'` | | `{"x": 1, "y": "2"}` | |
| `on_tool_end` | `'some_tool'` | | | `{"x": 1, "y": "2"}` |
| `on_retriever_start` | `'[retriever name]'` | | `{"query": "hello"}` | |
| `on_retriever_end` | `'[retriever name]'` | | `{"query": "hello"}` | `[Document(...), ..]` |
| `on_prompt_start` | `'[template_name]'` | | `{"question": "hello"}` | |
| `on_prompt_end` | `'[template_name]'` | | `{"question": "hello"}` | `ChatPromptValue(messages: [SystemMessage, ...])` |
In addition to the standard events, users can also dispatch custom events (see example below).
@@ -1337,13 +1319,10 @@ class Runnable(ABC, Generic[Input, Output]):
A custom event has following format:
+-----------+------+-----------------------------------------------------------------------------------------------------------+
| Attribute | Type | Description |
+===========+======+===========================================================================================================+
| name | str | A user defined name for the event. |
+-----------+------+-----------------------------------------------------------------------------------------------------------+
| data | Any | The data associated with the event. This can be anything, though we suggest making it JSON serializable. |
+-----------+------+-----------------------------------------------------------------------------------------------------------+
| Attribute | Type | Description |
| ----------- | ------ | --------------------------------------------------------------------------------------------------------- |
| `name` | `str` | A user defined name for the event. |
| `data` | `Any` | The data associated with the event. This can be anything, though we suggest making it JSON serializable. |
Here are declarations associated with the standard events shown above:
@@ -1619,16 +1598,16 @@ class Runnable(ABC, Generic[Input, Output]):
from langchain_ollama import ChatOllama
from langchain_core.output_parsers import StrOutputParser
llm = ChatOllama(model="llama3.1")
model = ChatOllama(model="llama3.1")
# Without bind
chain = llm | StrOutputParser()
chain = model | StrOutputParser()
chain.invoke("Repeat quoted words exactly: 'One two three four five.'")
# Output is 'One two three four five.'
# With bind
chain = llm.bind(stop=["three"]) | StrOutputParser()
chain = model.bind(stop=["three"]) | StrOutputParser()
chain.invoke("Repeat quoted words exactly: 'One two three four five.'")
# Output is 'One two'
@@ -4493,7 +4472,7 @@ class RunnableLambda(Runnable[Input, Output]):
@override
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
"""The pydantic schema for the input to this `Runnable`.
"""The Pydantic schema for the input to this `Runnable`.
Args:
config: The config to use.
@@ -5127,7 +5106,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
None,
),
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# able to construct the Pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
@@ -5150,7 +5129,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
self.get_name("Output"),
root=list[schema], # type: ignore[valid-type]
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# able to construct the Pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
@@ -5387,13 +5366,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[
custom_input_type: Any | None = None
"""Override the input type of the underlying `Runnable` with a custom type.
The type can be a pydantic model, or a type annotation (e.g., `list[str]`).
The type can be a Pydantic model, or a type annotation (e.g., `list[str]`).
"""
# Union[Type[Output], BaseModel] + things like list[str]
custom_output_type: Any | None = None
"""Override the output type of the underlying `Runnable` with a custom type.
The type can be a pydantic model, or a type annotation (e.g., `list[str]`).
The type can be a Pydantic model, or a type annotation (e.g., `list[str]`).
"""
model_config = ConfigDict(
@@ -6077,10 +6056,10 @@ def chain(
@chain
def my_func(fields):
prompt = PromptTemplate("Hello, {name}!")
llm = OpenAI()
model = OpenAI()
formatted = prompt.invoke(**fields)
for chunk in llm.stream(formatted):
for chunk in model.stream(formatted):
yield chunk
```
"""

View File

@@ -594,7 +594,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
Returns:
If the attribute is anything other than a method that outputs a Runnable,
returns getattr(self.runnable, name). If the attribute is a method that
does return a new Runnable (e.g. llm.bind_tools([...]) outputs a new
does return a new Runnable (e.g. model.bind_tools([...]) outputs a new
RunnableBinding) then self.runnable and each of the runnables in
self.fallbacks is replaced with getattr(x, name).
@@ -605,15 +605,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
gpt_4o = ChatOpenAI(model="gpt-4o")
claude_3_sonnet = ChatAnthropic(model="claude-3-7-sonnet-20250219")
llm = gpt_4o.with_fallbacks([claude_3_sonnet])
model = gpt_4o.with_fallbacks([claude_3_sonnet])
llm.model_name
model.model_name
# -> "gpt-4o"
# .bind_tools() is called on both ChatOpenAI and ChatAnthropic
# Equivalent to:
# gpt_4o.bind_tools([...]).with_fallbacks([claude_3_sonnet.bind_tools([...])])
llm.bind_tools([...])
model.bind_tools([...])
# -> RunnableWithFallbacks(
runnable=RunnableBinding(bound=ChatOpenAI(...), kwargs={"tools": [...]}),
fallbacks=[RunnableBinding(bound=ChatAnthropic(...), kwargs={"tools": [...]})],

View File

@@ -52,7 +52,7 @@ def is_uuid(value: str) -> bool:
value: The string to check.
Returns:
True if the string is a valid UUID, False otherwise.
`True` if the string is a valid UUID, `False` otherwise.
"""
try:
UUID(value)

View File

@@ -1,4 +1,4 @@
"""Module contains typedefs that are used with Runnables."""
"""Module contains typedefs that are used with `Runnable` objects."""
from __future__ import annotations
@@ -14,43 +14,43 @@ class EventData(TypedDict, total=False):
"""Data associated with a streaming event."""
input: Any
"""The input passed to the Runnable that generated the event.
"""The input passed to the `Runnable` that generated the event.
Inputs will sometimes be available at the *START* of the Runnable, and
sometimes at the *END* of the Runnable.
Inputs will sometimes be available at the *START* of the `Runnable`, and
sometimes at the *END* of the `Runnable`.
If a Runnable is able to stream its inputs, then its input by definition
won't be known until the *END* of the Runnable when it has finished streaming
If a `Runnable` is able to stream its inputs, then its input by definition
won't be known until the *END* of the `Runnable` when it has finished streaming
its inputs.
"""
error: NotRequired[BaseException]
"""The error that occurred during the execution of the Runnable.
"""The error that occurred during the execution of the `Runnable`.
This field is only available if the Runnable raised an exception.
This field is only available if the `Runnable` raised an exception.
!!! version-added "Added in version 1.0.0"
"""
output: Any
"""The output of the Runnable that generated the event.
"""The output of the `Runnable` that generated the event.
Outputs will only be available at the *END* of the Runnable.
Outputs will only be available at the *END* of the `Runnable`.
For most Runnables, this field can be inferred from the `chunk` field,
though there might be some exceptions for special cased Runnables (e.g., like
For most `Runnable` objects, this field can be inferred from the `chunk` field,
though there might be some exceptions for special a cased `Runnable` (e.g., like
chat models), which may return more information.
"""
chunk: Any
"""A streaming chunk from the output that generated the event.
chunks support addition in general, and adding them up should result
in the output of the Runnable that generated the event.
in the output of the `Runnable` that generated the event.
"""
class BaseStreamEvent(TypedDict):
"""Streaming event.
Schema of a streaming event which is produced from the astream_events method.
Schema of a streaming event which is produced from the `astream_events` method.
Example:
```python
@@ -94,45 +94,45 @@ class BaseStreamEvent(TypedDict):
"""
event: str
"""Event names are of the format: on_[runnable_type]_(start|stream|end).
"""Event names are of the format: `on_[runnable_type]_(start|stream|end)`.
Runnable types are one of:
- **llm** - used by non chat models
- **chat_model** - used by chat models
- **prompt** -- e.g., ChatPromptTemplate
- **tool** -- from tools defined via @tool decorator or inheriting
from Tool/BaseTool
- **chain** - most Runnables are of this type
- **prompt** -- e.g., `ChatPromptTemplate`
- **tool** -- from tools defined via `@tool` decorator or inheriting
from `Tool`/`BaseTool`
- **chain** - most `Runnable` objects are of this type
Further, the events are categorized as one of:
- **start** - when the Runnable starts
- **stream** - when the Runnable is streaming
- **end* - when the Runnable ends
- **start** - when the `Runnable` starts
- **stream** - when the `Runnable` is streaming
- **end* - when the `Runnable` ends
start, stream and end are associated with slightly different `data` payload.
Please see the documentation for `EventData` for more details.
"""
run_id: str
"""An randomly generated ID to keep track of the execution of the given Runnable.
"""An randomly generated ID to keep track of the execution of the given `Runnable`.
Each child Runnable that gets invoked as part of the execution of a parent Runnable
is assigned its own unique ID.
Each child `Runnable` that gets invoked as part of the execution of a parent
`Runnable` is assigned its own unique ID.
"""
tags: NotRequired[list[str]]
"""Tags associated with the Runnable that generated this event.
"""Tags associated with the `Runnable` that generated this event.
Tags are always inherited from parent Runnables.
Tags are always inherited from parent `Runnable` objects.
Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
Tags can either be bound to a `Runnable` using `.with_config({"tags": ["hello"]})`
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
"""
metadata: NotRequired[dict[str, Any]]
"""Metadata associated with the Runnable that generated this event.
"""Metadata associated with the `Runnable` that generated this event.
Metadata can either be bound to a Runnable using
Metadata can either be bound to a `Runnable` using
`.with_config({"metadata": { "foo": "bar" }})`
@@ -146,8 +146,8 @@ class BaseStreamEvent(TypedDict):
Root Events will have an empty list.
For example, if a Runnable A calls Runnable B, then the event generated by Runnable
B will have Runnable A's ID in the parent_ids field.
For example, if a `Runnable` A calls `Runnable` B, then the event generated by
`Runnable` B will have `Runnable` A's ID in the `parent_ids` field.
The order of the parent IDs is from the root parent to the immediate parent.
@@ -164,7 +164,7 @@ class StandardStreamEvent(BaseStreamEvent):
The contents of the event data depend on the event type.
"""
name: str
"""The name of the Runnable that generated the event."""
"""The name of the `Runnable` that generated the event."""
class CustomStreamEvent(BaseStreamEvent):

View File

@@ -80,7 +80,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool: # noqa: A002
callable: The callable to check.
Returns:
True if the callable accepts a run_manager argument, False otherwise.
`True` if the callable accepts a run_manager argument, `False` otherwise.
"""
try:
return signature(callable).parameters.get("run_manager") is not None
@@ -95,7 +95,7 @@ def accepts_config(callable: Callable[..., Any]) -> bool: # noqa: A002
callable: The callable to check.
Returns:
True if the callable accepts a config argument, False otherwise.
`True` if the callable accepts a config argument, `False` otherwise.
"""
try:
return signature(callable).parameters.get("config") is not None
@@ -110,7 +110,7 @@ def accepts_context(callable: Callable[..., Any]) -> bool: # noqa: A002
callable: The callable to check.
Returns:
True if the callable accepts a context argument, False otherwise.
`True` if the callable accepts a context argument, `False` otherwise.
"""
try:
return signature(callable).parameters.get("context") is not None
@@ -123,7 +123,7 @@ def asyncio_accepts_context() -> bool:
"""Cache the result of checking if asyncio.create_task accepts a `context` arg.
Returns:
True if `asyncio.create_task` accepts a context argument, False otherwise.
True if `asyncio.create_task` accepts a context argument, `False` otherwise.
"""
return accepts_context(asyncio.create_task)
@@ -727,7 +727,7 @@ def is_async_generator(
func: The function to check.
Returns:
True if the function is an async generator, False otherwise.
`True` if the function is an async generator, `False` otherwise.
"""
return inspect.isasyncgenfunction(func) or (
hasattr(func, "__call__") # noqa: B004
@@ -744,7 +744,7 @@ def is_async_callable(
func: The function to check.
Returns:
True if the function is async, False otherwise.
`True` if the function is async, `False` otherwise.
"""
return asyncio.iscoroutinefunction(func) or (
hasattr(func, "__call__") # noqa: B004

View File

@@ -92,7 +92,7 @@ def _is_annotated_type(typ: type[Any]) -> bool:
typ: The type to check.
Returns:
True if the type is an Annotated type, False otherwise.
`True` if the type is an Annotated type, `False` otherwise.
"""
return get_origin(typ) is typing.Annotated
@@ -226,7 +226,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
pydantic_version: The Pydantic version to check against ("v1" or "v2").
Returns:
True if the annotation is a Pydantic model, False otherwise.
`True` if the annotation is a Pydantic model, `False` otherwise.
"""
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
try:
@@ -245,7 +245,7 @@ def _function_annotations_are_pydantic_v1(
func: The function being checked.
Returns:
True if all Pydantic annotations are from V1, False otherwise.
True if all Pydantic annotations are from V1, `False` otherwise.
Raises:
NotImplementedError: If the function contains mixed V1 and V2 annotations.
@@ -285,17 +285,17 @@ def create_schema_from_function(
error_on_invalid_docstring: bool = False,
include_injected: bool = True,
) -> type[BaseModel]:
"""Create a pydantic schema from a function's signature.
"""Create a Pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydantic schema.
model_name: Name to assign to the generated Pydantic schema.
func: Function to generate the schema from.
filter_args: Optional list of arguments to exclude from the schema.
Defaults to FILTERED_ARGS.
Defaults to `FILTERED_ARGS`.
parse_docstring: Whether to parse the function's docstring for descriptions
for each argument. Defaults to `False`.
error_on_invalid_docstring: if `parse_docstring` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
whether to raise `ValueError` on invalid Google Style docstrings.
Defaults to `False`.
include_injected: Whether to include injected arguments in the schema.
Defaults to `True`, since we want to include them in the schema
@@ -312,7 +312,7 @@ def create_schema_from_function(
# https://docs.pydantic.dev/latest/usage/validation_decorator/
with warnings.catch_warnings():
# We are using deprecated functionality here.
# This code should be re-written to simply construct a pydantic model
# This code should be re-written to simply construct a Pydantic model
# using inspect.signature and create_model.
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
@@ -517,7 +517,7 @@ class ChildTool(BaseTool):
"""Check if the tool accepts only a single input argument.
Returns:
True if the tool has only one input argument, False otherwise.
`True` if the tool has only one input argument, `False` otherwise.
"""
keys = {k for k in self.args if k != "kwargs"}
return len(keys) == 1
@@ -981,7 +981,7 @@ def _is_tool_call(x: Any) -> bool:
x: The input to check.
Returns:
True if the input is a tool call, False otherwise.
`True` if the input is a tool call, `False` otherwise.
"""
return isinstance(x, dict) and x.get("type") == "tool_call"
@@ -1128,7 +1128,7 @@ def _is_message_content_type(obj: Any) -> bool:
obj: The object to check.
Returns:
True if the object is valid message content, False otherwise.
`True` if the object is valid message content, `False` otherwise.
"""
return isinstance(obj, str) or (
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
@@ -1144,7 +1144,7 @@ def _is_message_content_block(obj: Any) -> bool:
obj: The object to check.
Returns:
True if the object is a valid content block, False otherwise.
`True` if the object is a valid content block, `False` otherwise.
"""
if isinstance(obj, str):
return True
@@ -1248,7 +1248,7 @@ def _is_injected_arg_type(
injected_type: The specific injected type to check for.
Returns:
True if the type is an injected argument, False otherwise.
`True` if the type is an injected argument, `False` otherwise.
"""
injected_type = injected_type or InjectedToolArg
return any(

View File

@@ -69,7 +69,7 @@ class Tool(BaseTool):
def _to_args_and_kwargs(
self, tool_input: str | dict, tool_call_id: str | None
) -> tuple[tuple, dict]:
"""Convert tool input to pydantic model.
"""Convert tool input to Pydantic model.
Args:
tool_input: The input to the tool.
@@ -79,8 +79,7 @@ class Tool(BaseTool):
ToolException: If the tool input is invalid.
Returns:
the pydantic model args and kwargs.
The Pydantic model args and kwargs.
"""
args, kwargs = super()._to_args_and_kwargs(tool_input, tool_call_id)
# For backwards compatibility. The tool must be run with a single input

View File

@@ -190,7 +190,7 @@ class RunLog(RunLogPatch):
other: The other `RunLog` to compare to.
Returns:
True if the `RunLog`s are equal, False otherwise.
`True` if the `RunLog`s are equal, `False` otherwise.
"""
# First compare that the state is the same
if not isinstance(other, RunLog):
@@ -288,7 +288,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
*ops: The operations to send to the stream.
Returns:
True if the patch was sent successfully, False if the stream is closed.
`True` if the patch was sent successfully, False if the stream is closed.
"""
# We will likely want to wrap this in try / except at some point
# to handle exceptions that might arise at run time.
@@ -368,7 +368,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
run: The Run to check.
Returns:
True if the run should be included, False otherwise.
`True` if the run should be included, `False` otherwise.
"""
if run.id == self.root_id:
return False

View File

@@ -13,7 +13,7 @@ def env_var_is_set(env_var: str) -> bool:
env_var: The name of the environment variable.
Returns:
True if the environment variable is set, False otherwise.
`True` if the environment variable is set, `False` otherwise.
"""
return env_var in os.environ and os.environ[env_var] not in {
"",

View File

@@ -713,7 +713,7 @@ def tool_example_to_messages(
"type": "function",
"function": {
# The name of the function right now corresponds to the name
# of the pydantic model. This is implicit in the API right now,
# of the Pydantic model. This is implicit in the API right now,
# and will be improved over time.
"name": tool_call.__class__.__name__,
"arguments": tool_call.model_dump_json(),

View File

@@ -7,6 +7,6 @@ def is_interactive_env() -> bool:
"""Determine if running within IPython or Jupyter.
Returns:
True if running in an interactive environment, False otherwise.
True if running in an interactive environment, `False` otherwise.
"""
return hasattr(sys, "ps2")

View File

@@ -78,7 +78,7 @@ def is_pydantic_v1_subclass(cls: type) -> bool:
"""Check if the given class is Pydantic v1-like.
Returns:
True if the given class is a subclass of Pydantic `BaseModel` 1.x.
`True` if the given class is a subclass of Pydantic `BaseModel` 1.x.
"""
return issubclass(cls, BaseModelV1)
@@ -87,7 +87,7 @@ def is_pydantic_v2_subclass(cls: type) -> bool:
"""Check if the given class is Pydantic v2-like.
Returns:
True if the given class is a subclass of Pydantic BaseModel 2.x.
`True` if the given class is a subclass of Pydantic BaseModel 2.x.
"""
return issubclass(cls, BaseModel)
@@ -101,7 +101,7 @@ def is_basemodel_subclass(cls: type) -> bool:
* pydantic.v1.BaseModel in Pydantic 2.x
Returns:
True if the given class is a subclass of Pydantic `BaseModel`.
`True` if the given class is a subclass of Pydantic `BaseModel`.
"""
# Before we can use issubclass on the cls we need to check if it is a class
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
@@ -119,7 +119,7 @@ def is_basemodel_instance(obj: Any) -> bool:
* pydantic.v1.BaseModel in Pydantic 2.x
Returns:
True if the given class is an instance of Pydantic `BaseModel`.
`True` if the given class is an instance of Pydantic `BaseModel`.
"""
return isinstance(obj, (BaseModel, BaseModelV1))
@@ -206,7 +206,7 @@ def _create_subset_model_v1(
descriptions: dict | None = None,
fn_description: str | None = None,
) -> type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
"""Create a Pydantic model with only a subset of model's fields."""
fields = {}
for field_name in field_names:
@@ -235,7 +235,7 @@ def _create_subset_model_v2(
descriptions: dict | None = None,
fn_description: str | None = None,
) -> type[BaseModel]:
"""Create a pydantic model with a subset of the model fields."""
"""Create a Pydantic model with a subset of the model fields."""
descriptions_ = descriptions or {}
fields = {}
for field_name in field_names:
@@ -438,9 +438,9 @@ def create_model(
/,
**field_definitions: Any,
) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions.
"""Create a Pydantic model with the given field definitions.
Please use create_model_v2 instead of this function.
Please use `create_model_v2` instead of this function.
Args:
model_name: The name of the model.
@@ -511,7 +511,7 @@ def create_model_v2(
field_definitions: dict[str, Any] | None = None,
root: Any | None = None,
) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions.
"""Create a Pydantic model with the given field definitions.
Attention:
Please do not use outside of langchain packages. This API
@@ -522,7 +522,7 @@ def create_model_v2(
module_name: The name of the module where the model is defined.
This is used by Pydantic to resolve any forward references.
field_definitions: The field definitions for the model.
root: Type for a root model (RootModel)
root: Type for a root model (`RootModel`)
Returns:
The created model.

View File

@@ -57,7 +57,7 @@ def _content_blocks_equal_ignore_id(
expected: Expected content to compare against (string or list of blocks).
Returns:
True if content matches (excluding `id` fields), False otherwise.
True if content matches (excluding `id` fields), `False` otherwise.
"""
if isinstance(actual, str) or isinstance(expected, str):

View File

@@ -1934,11 +1934,10 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
def test_args_schema_explicitly_typed() -> None:
"""This should test that one can type the args schema as a pydantic model.
Please note that this will test using pydantic 2 even though BaseTool
is a pydantic 1 model!
"""This should test that one can type the args schema as a Pydantic model.
Please note that this will test using pydantic 2 even though `BaseTool`
is a Pydantic 1 model!
"""
class Foo(BaseModel):
@@ -1981,7 +1980,7 @@ def test_args_schema_explicitly_typed() -> None:
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
"""This should test that one can type the args schema as a pydantic model."""
"""This should test that one can type the args schema as a Pydantic model."""
def foo(a: int, b: str) -> str:
"""Hahaha."""

View File

@@ -7,9 +7,9 @@ from langchain_core.tracers.context import collect_runs
def test_collect_runs() -> None:
llm = FakeListLLM(responses=["hello"])
model = FakeListLLM(responses=["hello"])
with collect_runs() as cb:
llm.invoke("hi")
model.invoke("hi")
assert cb.traced_runs
assert len(cb.traced_runs) == 1
assert isinstance(cb.traced_runs[0].id, uuid.UUID)

View File

@@ -124,7 +124,7 @@ class BaseSingleActionAgent(BaseModel):
along with observations.
Returns:
AgentFinish: Agent finish object.
Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not supported.
@@ -155,7 +155,7 @@ class BaseSingleActionAgent(BaseModel):
kwargs: Additional arguments.
Returns:
BaseSingleActionAgent: Agent object.
Agent object.
"""
raise NotImplementedError
@@ -169,7 +169,7 @@ class BaseSingleActionAgent(BaseModel):
"""Return dictionary representation of agent.
Returns:
Dict: Dictionary representation of agent.
Dictionary representation of agent.
"""
_dict = super().model_dump()
try:
@@ -233,7 +233,7 @@ class BaseMultiActionAgent(BaseModel):
"""Get allowed tools.
Returns:
list[str] | None: Allowed tools.
Allowed tools.
"""
return None
@@ -297,7 +297,7 @@ class BaseMultiActionAgent(BaseModel):
along with observations.
Returns:
AgentFinish: Agent finish object.
Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not supported.
@@ -388,8 +388,7 @@ class MultiActionAgentOutputParser(
text: Text to parse.
Returns:
Union[List[AgentAction], AgentFinish]:
List of agent actions or agent finish.
List of agent actions or agent finish.
"""
@@ -812,7 +811,7 @@ class Agent(BaseSingleActionAgent):
**kwargs: User inputs.
Returns:
Dict[str, Any]: Full inputs for the LLMChain.
Full inputs for the LLMChain.
"""
thoughts = self._construct_scratchpad(intermediate_steps)
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
@@ -834,7 +833,7 @@ class Agent(BaseSingleActionAgent):
values: Values to validate.
Returns:
Dict: Validated values.
Validated values.
Raises:
ValueError: If `agent_scratchpad` is not in prompt.input_variables
@@ -875,7 +874,7 @@ class Agent(BaseSingleActionAgent):
tools: Tools to use.
Returns:
BasePromptTemplate: Prompt template.
Prompt template.
"""
@classmethod
@@ -910,7 +909,7 @@ class Agent(BaseSingleActionAgent):
kwargs: Additional arguments.
Returns:
Agent: Agent object.
Agent object.
"""
cls._validate_tools(tools)
llm_chain = LLMChain(
@@ -942,7 +941,7 @@ class Agent(BaseSingleActionAgent):
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
Agent finish object.
Raises:
ValueError: If `early_stopping_method` is not in ['force', 'generate'].
@@ -1082,7 +1081,7 @@ class AgentExecutor(Chain):
kwargs: Additional arguments.
Returns:
AgentExecutor: Agent executor object.
Agent executor object.
"""
return cls(
agent=agent,
@@ -1099,7 +1098,7 @@ class AgentExecutor(Chain):
values: Values to validate.
Returns:
Dict: Validated values.
Validated values.
Raises:
ValueError: If allowed tools are different than provided tools.
@@ -1126,7 +1125,7 @@ class AgentExecutor(Chain):
values: Values to validate.
Returns:
Dict: Validated values.
Validated values.
"""
agent = values.get("agent")
if agent and isinstance(agent, Runnable):
@@ -1209,7 +1208,7 @@ class AgentExecutor(Chain):
async_: Whether to run async. (Ignored)
Returns:
AgentExecutorIterator: Agent executor iterator object.
Agent executor iterator object.
"""
return AgentExecutorIterator(
self,
@@ -1244,7 +1243,7 @@ class AgentExecutor(Chain):
name: Name of tool.
Returns:
BaseTool: Tool object.
Tool object.
"""
return {tool.name: tool for tool in self.tools}[name]
@@ -1759,7 +1758,7 @@ class AgentExecutor(Chain):
kwargs: Additional arguments.
Yields:
AddableDict: Addable dictionary.
Addable dictionary.
"""
config = ensure_config(config)
iterator = AgentExecutorIterator(
@@ -1790,7 +1789,7 @@ class AgentExecutor(Chain):
kwargs: Additional arguments.
Yields:
AddableDict: Addable dictionary.
Addable dictionary.
"""
config = ensure_config(config)
iterator = AgentExecutorIterator(

View File

@@ -58,7 +58,7 @@ def create_vectorstore_agent(
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.prebuilt import create_react_agent
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
vector_store = InMemoryVectorStore.from_texts(
[
@@ -74,7 +74,7 @@ def create_vectorstore_agent(
"Fetches information about pets.",
)
agent = create_react_agent(llm, [tool])
agent = create_react_agent(model, [tool])
for step in agent.stream(
{"messages": [("human", "What are dogs known for?")]},
@@ -156,7 +156,7 @@ def create_vectorstore_router_agent(
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.prebuilt import create_react_agent
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
pet_vector_store = InMemoryVectorStore.from_texts(
[
@@ -187,7 +187,7 @@ def create_vectorstore_router_agent(
),
]
agent = create_react_agent(llm, tools)
agent = create_react_agent(model, tools)
for step in agent.stream(
{"messages": [("human", "Tell me about carrots.")]},

View File

@@ -16,7 +16,7 @@ def format_log_to_str(
Defaults to "Thought: ".
Returns:
str: The scratchpad.
The scratchpad.
"""
thoughts = ""
for action, observation in intermediate_steps:

View File

@@ -14,7 +14,7 @@ def format_log_to_messages(
Defaults to "{observation}".
Returns:
List[BaseMessage]: The scratchpad.
The scratchpad.
"""
thoughts: list[BaseMessage] = []
for action, observation in intermediate_steps:

View File

@@ -42,7 +42,7 @@ def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
limit_to_domains: The allowed domains.
Returns:
True if the URL is in the allowed domains, False otherwise.
`True` if the URL is in the allowed domains, `False` otherwise.
"""
scheme, domain = _extract_scheme_and_domain(url)
@@ -143,7 +143,7 @@ try:
description: Limit the number of results
\"\"\"
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
toolkit = RequestsToolkit(
requests_wrapper=TextRequestsWrapper(headers={}), # no auth required
allow_dangerous_requests=ALLOW_DANGEROUS_REQUESTS,
@@ -152,7 +152,7 @@ try:
api_request_chain = (
API_URL_PROMPT.partial(api_docs=api_spec)
| llm.bind_tools(tools, tool_choice="any")
| model.bind_tools(tools, tool_choice="any")
)
class ChainState(TypedDict):
@@ -169,7 +169,7 @@ try:
return {"messages": [response]}
async def acall_model(state: ChainState, config: RunnableConfig):
response = await llm.ainvoke(state["messages"], config)
response = await model.ainvoke(state["messages"], config)
return {"messages": [response]}
graph_builder = StateGraph(ChainState)

View File

@@ -53,16 +53,16 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
input_variables=["page_content"], template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
model = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template("Summarize this content: {context}")
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = LLMChain(llm=model, prompt=prompt)
# We now define how to combine these summaries
reduce_prompt = PromptTemplate.from_template(
"Combine these summaries: {context}"
)
reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
reduce_llm_chain = LLMChain(llm=model, prompt=reduce_prompt)
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_llm_chain,
document_prompt=document_prompt,
@@ -79,7 +79,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
# which is specifically aimed at collapsing documents BEFORE
# the final call.
prompt = PromptTemplate.from_template("Collapse this content: {context}")
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = LLMChain(llm=model, prompt=prompt)
collapse_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,

View File

@@ -42,7 +42,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
from langchain_classic.output_parsers.regex import RegexParser
document_variable_name = "context"
llm = OpenAI()
model = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
# The actual prompt will need to be a lot more complex, this is just
@@ -61,7 +61,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
input_variables=["context"],
output_parser=output_parser,
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = LLMChain(llm=model, prompt=prompt)
chain = MapRerankDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,

View File

@@ -171,11 +171,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
input_variables=["page_content"], template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
model = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template("Summarize this content: {context}")
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = LLMChain(llm=model, prompt=prompt)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
@@ -188,7 +188,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
# which is specifically aimed at collapsing documents BEFORE
# the final call.
prompt = PromptTemplate.from_template("Collapse this content: {context}")
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = LLMChain(llm=model, prompt=prompt)
collapse_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,

View File

@@ -55,11 +55,11 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
input_variables=["page_content"], template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
model = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template("Summarize this content: {context}")
initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
initial_llm_chain = LLMChain(llm=model, prompt=prompt)
initial_response_name = "prev_response"
# The prompt here should take as an input variable the
# `document_variable_name` as well as `initial_response_name`
@@ -67,7 +67,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
"Here's your first summary: {prev_response}. "
"Now add to it based on the following context: {context}"
)
refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
refine_llm_chain = LLMChain(llm=model, prompt=prompt_refine)
chain = RefineDocumentsChain(
initial_llm_chain=initial_llm_chain,
refine_llm_chain=refine_llm_chain,

View File

@@ -68,8 +68,8 @@ def create_stuff_documents_chain(
prompt = ChatPromptTemplate.from_messages(
[("system", "What are everyone's favorite colors:\n\n{context}")]
)
llm = ChatOpenAI(model="gpt-3.5-turbo")
chain = create_stuff_documents_chain(llm, prompt)
model = ChatOpenAI(model="gpt-3.5-turbo")
chain = create_stuff_documents_chain(model, prompt)
docs = [
Document(page_content="Jesse loves red but not yellow"),
@@ -132,11 +132,11 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
input_variables=["page_content"], template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
model = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template("Summarize this content: {context}")
llm_chain = LLMChain(llm=llm, prompt=prompt)
llm_chain = LLMChain(llm=model, prompt=prompt)
chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,

View File

@@ -58,7 +58,7 @@ class ConstitutionalChain(Chain):
from langgraph.graph import END, START, StateGraph
from typing_extensions import Annotated, TypedDict
llm = ChatOpenAI(model="gpt-4o-mini")
model = ChatOpenAI(model="gpt-4o-mini")
class Critique(TypedDict):
"""Generate a critique, if needed."""
@@ -86,9 +86,9 @@ class ConstitutionalChain(Chain):
"Revision Request: {revision_request}"
)
chain = llm | StrOutputParser()
critique_chain = critique_prompt | llm.with_structured_output(Critique)
revision_chain = revision_prompt | llm | StrOutputParser()
chain = model | StrOutputParser()
critique_chain = critique_prompt | model.with_structured_output(Critique)
revision_chain = revision_prompt | model | StrOutputParser()
class State(TypedDict):
@@ -170,16 +170,16 @@ class ConstitutionalChain(Chain):
from langchain_classic.chains.constitutional_ai.models \
import ConstitutionalPrinciple
llm = OpenAI()
llmodelm = OpenAI()
qa_prompt = PromptTemplate(
template="Q: {question} A:",
input_variables=["question"],
)
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
qa_chain = LLMChain(llm=model, prompt=qa_prompt)
constitutional_chain = ConstitutionalChain.from_llm(
llm=llm,
llm=model,
chain=qa_chain,
constitutional_principles=[
ConstitutionalPrinciple(

View File

@@ -47,9 +47,9 @@ class ConversationChain(LLMChain):
return store[session_id]
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
model = ChatOpenAI(model="gpt-3.5-turbo-0125")
chain = RunnableWithMessageHistory(llm, get_session_history)
chain = RunnableWithMessageHistory(model, get_session_history)
chain.invoke(
"Hi I'm Bob.",
config={"configurable": {"session_id": "1"}},
@@ -85,9 +85,9 @@ class ConversationChain(LLMChain):
return store[session_id]
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
model = ChatOpenAI(model="gpt-3.5-turbo-0125")
chain = RunnableWithMessageHistory(llm, get_session_history)
chain = RunnableWithMessageHistory(model, get_session_history)
chain.invoke(
"Hi I'm Bob.",
config={"configurable": {"session_id": "1"}},

View File

@@ -283,7 +283,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
retriever = ... # Your retriever
llm = ChatOpenAI()
model = ChatOpenAI()
# Contextualize question
contextualize_q_system_prompt = (
@@ -301,7 +301,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
]
)
history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt
model, retriever, contextualize_q_prompt
)
# Answer question
@@ -324,7 +324,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
# Below we use create_stuff_documents_chain to feed all retrieved context
# into the LLM. Note that we can also use StuffDocumentsChain and other
# instances of BaseCombineDocumentsChain.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
question_answer_chain = create_stuff_documents_chain(model, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
# Usage:
@@ -371,8 +371,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"Follow up question: {question}"
)
prompt = PromptTemplate.from_template(template)
llm = OpenAI()
question_generator_chain = LLMChain(llm=llm, prompt=prompt)
model = OpenAI()
question_generator_chain = LLMChain(llm=model, prompt=prompt)
chain = ConversationalRetrievalChain(
combine_docs_chain=combine_docs_chain,
retriever=retriever,

View File

@@ -38,10 +38,10 @@ def create_history_aware_retriever(
from langchain_classic import hub
rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase")
llm = ChatOpenAI()
model = ChatOpenAI()
retriever = ...
chat_retriever_chain = create_history_aware_retriever(
llm, retriever, rephrase_prompt
model, retriever, rephrase_prompt
)
chain.invoke({"input": "...", "chat_history": })

View File

@@ -55,8 +55,8 @@ class LLMChain(Chain):
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(input_variables=["adjective"], template=prompt_template)
llm = OpenAI()
chain = prompt | llm | StrOutputParser()
model = OpenAI()
chain = prompt | model | StrOutputParser()
chain.invoke("your adjective here")
```
@@ -69,7 +69,7 @@ class LLMChain(Chain):
prompt_template = "Tell me a {adjective} joke"
prompt = PromptTemplate(input_variables=["adjective"], template=prompt_template)
llm = LLMChain(llm=OpenAI(), prompt=prompt)
model = LLMChain(llm=OpenAI(), prompt=prompt)
```
"""

View File

@@ -80,8 +80,8 @@ class LLMCheckerChain(Chain):
from langchain_community.llms import OpenAI
from langchain_classic.chains import LLMCheckerChain
llm = OpenAI(temperature=0.7)
checker_chain = LLMCheckerChain.from_llm(llm)
model = OpenAI(temperature=0.7)
checker_chain = LLMCheckerChain.from_llm(model)
```
"""

View File

@@ -84,9 +84,9 @@ class LLMMathChain(Chain):
)
)
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
tools = [calculator]
llm_with_tools = llm.bind_tools(tools, tool_choice="any")
model_with_tools = model.bind_tools(tools, tool_choice="any")
class ChainState(TypedDict):
\"\"\"LangGraph state.\"\"\"
@@ -95,11 +95,11 @@ class LLMMathChain(Chain):
async def acall_chain(state: ChainState, config: RunnableConfig):
last_message = state["messages"][-1]
response = await llm_with_tools.ainvoke(state["messages"], config)
response = await model_with_tools.ainvoke(state["messages"], config)
return {"messages": [response]}
async def acall_model(state: ChainState, config: RunnableConfig):
response = await llm.ainvoke(state["messages"], config)
response = await model.ainvoke(state["messages"], config)
return {"messages": [response]}
graph_builder = StateGraph(ChainState)

View File

@@ -83,8 +83,8 @@ class LLMSummarizationCheckerChain(Chain):
from langchain_community.llms import OpenAI
from langchain_classic.chains import LLMSummarizationCheckerChain
llm = OpenAI(temperature=0.0)
checker_chain = LLMSummarizationCheckerChain.from_llm(llm)
model = OpenAI(temperature=0.0)
checker_chain = LLMSummarizationCheckerChain.from_llm(model)
```
"""

View File

@@ -84,8 +84,8 @@ class NatBotChain(Chain):
"""Load with default LLMChain."""
msg = (
"This method is no longer implemented. Please use from_llm."
"llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)"
"For example, NatBotChain.from_llm(llm, objective)"
"model = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)"
"For example, NatBotChain.from_llm(model, objective)"
)
raise NotImplementedError(msg)

View File

@@ -107,7 +107,7 @@ def create_openai_fn_chain(
fav_food: str | None = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-4", temperature=0)
model = ChatOpenAI(model="gpt-4", temperature=0)
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a world class algorithm for recording entities."),
@@ -115,7 +115,7 @@ def create_openai_fn_chain(
("human", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt)
chain = create_openai_fn_chain([RecordPerson, RecordDog], model, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
@@ -191,7 +191,7 @@ def create_structured_output_chain(
color: str = Field(..., description="The dog's color")
fav_food: str | None = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
model = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a world class algorithm for extracting information in structured formats."),
@@ -199,7 +199,7 @@ def create_structured_output_chain(
("human", "Tip: Make sure to answer in the correct format"),
]
)
chain = create_structured_output_chain(Dog, llm, prompt)
chain = create_structured_output_chain(Dog, model, prompt)
chain.run("Harry was a chubby brown beagle who loved chicken")
# -> Dog(name="Harry", color="brown", fav_food="chicken")

View File

@@ -83,12 +83,12 @@ def create_citation_fuzzy_match_runnable(llm: BaseChatModel) -> Runnable:
from langchain_classic.chains import create_citation_fuzzy_match_runnable
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")
model = ChatOpenAI(model="gpt-4o-mini")
context = "Alice has blue eyes. Bob has brown eyes. Charlie has green eyes."
question = "What color are Bob's eyes?"
chain = create_citation_fuzzy_match_runnable(llm)
chain = create_citation_fuzzy_match_runnable(model)
chain.invoke({"question": question, "context": context})
```

View File

@@ -73,8 +73,8 @@ Passage:
# to see an up to date list of which models support
# with_structured_output.
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
structured_llm = model.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats.
structured_model = model.with_structured_output(Joke)
structured_model.invoke("Tell me a joke about cats.
Make sure to call the Joke function.")
"""
),
@@ -143,8 +143,8 @@ def create_extraction_chain(
# to see an up to date list of which models support
# with_structured_output.
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
structured_llm = model.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats.
structured_model = model.with_structured_output(Joke)
structured_model.invoke("Tell me a joke about cats.
Make sure to call the Joke function.")
"""
),
@@ -155,10 +155,10 @@ def create_extraction_chain_pydantic(
prompt: BasePromptTemplate | None = None,
verbose: bool = False, # noqa: FBT001,FBT002
) -> Chain:
"""Creates a chain that extracts information from a passage using pydantic schema.
"""Creates a chain that extracts information from a passage using Pydantic schema.
Args:
pydantic_schema: The pydantic schema of the entities to extract.
pydantic_schema: The Pydantic schema of the entities to extract.
llm: The language model to use.
prompt: The prompt to use for extraction.
verbose: Whether to run in verbose mode. In verbose mode, some intermediate

View File

@@ -330,7 +330,7 @@ def get_openapi_chain(
prompt = ChatPromptTemplate.from_template(
"Use the provided APIs to respond to this user query:\\n\\n{query}"
)
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0).bind_tools(tools)
model = ChatOpenAI(model="gpt-4o-mini", temperature=0).bind_tools(tools)
def _execute_tool(message) -> Any:
if tool_calls := message.tool_calls:
@@ -341,7 +341,7 @@ def get_openapi_chain(
else:
return message.content
chain = prompt | llm | _execute_tool
chain = prompt | model | _execute_tool
```
```python
@@ -394,7 +394,7 @@ def get_openapi_chain(
msg = (
"Must provide an LLM for this chain.For example,\n"
"from langchain_openai import ChatOpenAI\n"
"llm = ChatOpenAI()\n"
"model = ChatOpenAI()\n"
)
raise ValueError(msg)
prompt = prompt or ChatPromptTemplate.from_template(

View File

@@ -77,8 +77,8 @@ def create_tagging_chain(
# to see an up to date list of which models support
# with_structured_output.
model = ChatAnthropic(model="claude-3-haiku-20240307", temperature=0)
structured_llm = model.with_structured_output(Joke)
structured_llm.invoke(
structured_model = model.with_structured_output(Joke)
structured_model.invoke(
"Why did the cat cross the road? To get to the other "
"side... and then lay down in the middle of it!"
)
@@ -93,7 +93,7 @@ def create_tagging_chain(
kwargs: Additional keyword arguments to pass to the chain.
Returns:
Chain (LLMChain) that can be used to extract information from a passage.
Chain (`LLMChain`) that can be used to extract information from a passage.
"""
function = _get_tagging_function(schema)
@@ -130,10 +130,10 @@ def create_tagging_chain_pydantic(
prompt: ChatPromptTemplate | None = None,
**kwargs: Any,
) -> Chain:
"""Create tagging chain from pydantic schema.
"""Create tagging chain from Pydantic schema.
Create a chain that extracts information from a passage
based on a pydantic schema.
based on a Pydantic schema.
This function is deprecated. Please use `with_structured_output` instead.
See example usage below:
@@ -153,8 +153,8 @@ def create_tagging_chain_pydantic(
# to see an up to date list of which models support
# with_structured_output.
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
structured_llm = model.with_structured_output(Joke)
structured_llm.invoke(
structured_model = model.with_structured_output(Joke)
structured_model.invoke(
"Why did the cat cross the road? To get to the other "
"side... and then lay down in the middle of it!"
)
@@ -163,13 +163,13 @@ def create_tagging_chain_pydantic(
Read more here: https://python.langchain.com/docs/how_to/structured_output/
Args:
pydantic_schema: The pydantic schema of the entities to extract.
pydantic_schema: The Pydantic schema of the entities to extract.
llm: The language model to use.
prompt: The prompt template to use for the chain.
kwargs: Additional keyword arguments to pass to the chain.
Returns:
Chain (LLMChain) that can be used to extract information from a passage.
Chain (`LLMChain`) that can be used to extract information from a passage.
"""
if hasattr(pydantic_schema, "model_json_schema"):

View File

@@ -42,8 +42,8 @@ If a property is not present and is not required in the function parameters, do
# to see an up to date list of which models support
# with_structured_output.
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
structured_llm = model.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats.
structured_model = model.with_structured_output(Joke)
structured_model.invoke("Tell me a joke about cats.
Make sure to call the Joke function.")
"""
),

View File

@@ -48,7 +48,7 @@ def is_llm(llm: BaseLanguageModel) -> bool:
llm: Language model to check.
Returns:
True if the language model is a BaseLLM model, False otherwise.
`True` if the language model is a BaseLLM model, `False` otherwise.
"""
return isinstance(llm, BaseLLM)
@@ -60,6 +60,6 @@ def is_chat_model(llm: BaseLanguageModel) -> bool:
llm: Language model to check.
Returns:
True if the language model is a BaseChatModel model, False otherwise.
`True` if the language model is a BaseChatModel model, `False` otherwise.
"""
return isinstance(llm, BaseChatModel)

View File

@@ -34,7 +34,7 @@ class QAGenerationChain(Chain):
- Supports async and streaming;
- Surfaces prompt and text splitter for easier customization;
- Use of JsonOutputParser supports JSONPatch operations in streaming mode,
as well as robustness to markdown.
as well as robustness to markdown.
```python
from langchain_classic.chains.qa_generation.prompt import (
@@ -52,14 +52,14 @@ class QAGenerationChain(Chain):
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
llm = ChatOpenAI()
model = ChatOpenAI()
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=500)
split_text = RunnableLambda(lambda x: text_splitter.create_documents([x]))
chain = RunnableParallel(
text=RunnablePassthrough(),
questions=(
split_text | RunnableEach(bound=prompt | llm | JsonOutputParser())
split_text | RunnableEach(bound=prompt | model | JsonOutputParser())
),
)
```

View File

@@ -112,7 +112,7 @@ class QueryTransformer(Transformer):
args: The arguments passed to the function.
Returns:
FilterDirective: The filter directive.
The filter directive.
Raises:
ValueError: If the function is a comparator and the first arg is not in the

View File

@@ -46,9 +46,11 @@ def create_retrieval_chain(
from langchain_classic import hub
retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
llm = ChatOpenAI()
model = ChatOpenAI()
retriever = ...
combine_docs_chain = create_stuff_documents_chain(llm, retrieval_qa_chat_prompt)
combine_docs_chain = create_stuff_documents_chain(
model, retrieval_qa_chat_prompt
)
retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)
retrieval_chain.invoke({"input": "..."})

View File

@@ -237,7 +237,7 @@ class RetrievalQA(BaseRetrievalQA):
retriever = ... # Your retriever
llm = ChatOpenAI()
model = ChatOpenAI()
system_prompt = (
"Use the given context to answer the question. "
@@ -251,7 +251,7 @@ class RetrievalQA(BaseRetrievalQA):
("human", "{input}"),
]
)
question_answer_chain = create_stuff_documents_chain(llm, prompt)
question_answer_chain = create_stuff_documents_chain(model, prompt)
chain = create_retrieval_chain(retriever, question_answer_chain)
chain.invoke({"input": query})

View File

@@ -48,7 +48,7 @@ class LLMRouterChain(RouterChain):
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")
model = ChatOpenAI(model="gpt-4o-mini")
prompt_1 = ChatPromptTemplate.from_messages(
[
@@ -63,8 +63,8 @@ class LLMRouterChain(RouterChain):
]
)
chain_1 = prompt_1 | llm | StrOutputParser()
chain_2 = prompt_2 | llm | StrOutputParser()
chain_1 = prompt_1 | model | StrOutputParser()
chain_2 = prompt_2 | model | StrOutputParser()
route_system = "Route the user's query to either the animal "
"or vegetable expert."
@@ -83,7 +83,7 @@ class LLMRouterChain(RouterChain):
route_chain = (
route_prompt
| llm.with_structured_output(RouteQuery)
| model.with_structured_output(RouteQuery)
| itemgetter("destination")
)

View File

@@ -49,7 +49,7 @@ class MultiPromptChain(MultiRouteChain):
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict
llm = ChatOpenAI(model="gpt-4o-mini")
model = ChatOpenAI(model="gpt-4o-mini")
# Define the prompts we will route to
prompt_1 = ChatPromptTemplate.from_messages(
@@ -68,8 +68,8 @@ class MultiPromptChain(MultiRouteChain):
# Construct the chains we will route to. These format the input query
# into the respective prompt, run it through a chat model, and cast
# the result to a string.
chain_1 = prompt_1 | llm | StrOutputParser()
chain_2 = prompt_2 | llm | StrOutputParser()
chain_1 = prompt_1 | model | StrOutputParser()
chain_2 = prompt_2 | model | StrOutputParser()
# Next: define the chain that selects which branch to route to.
@@ -92,7 +92,7 @@ class MultiPromptChain(MultiRouteChain):
destination: Literal["animal", "vegetable"]
route_chain = route_prompt | llm.with_structured_output(RouteQuery)
route_chain = route_prompt | model.with_structured_output(RouteQuery)
# For LangGraph, we will define the state of the graph to hold the query,

View File

@@ -117,7 +117,7 @@ class MultiRetrievalQAChain(MultiRouteChain):
"default LLMs on behalf of users."
"You can provide a conversation LLM like so:\n"
"from langchain_openai import ChatOpenAI\n"
"llm = ChatOpenAI()"
"model = ChatOpenAI()"
)
raise NotImplementedError(msg)
_default_chain = ConversationChain(

View File

@@ -76,8 +76,8 @@ def create_sql_query_chain(
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db)
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(model, db)
response = chain.invoke({"question": "How many employees are there"})
```

View File

@@ -57,8 +57,8 @@ from pydantic import BaseModel
# to see an up to date list of which models support
# with_structured_output.
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
structured_llm = model.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats.
structured_model = model.with_structured_output(Joke)
structured_model.invoke("Tell me a joke about cats.
Make sure to call the Joke function.")
"""
),
@@ -127,9 +127,9 @@ def create_openai_fn_runnable(
fav_food: str | None = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-4", temperature=0)
structured_llm = create_openai_fn_runnable([RecordPerson, RecordDog], llm)
structured_llm.invoke("Harry was a chubby brown beagle who loved chicken)
model = ChatOpenAI(model="gpt-4", temperature=0)
structured_model = create_openai_fn_runnable([RecordPerson, RecordDog], model)
structured_model.invoke("Harry was a chubby brown beagle who loved chicken)
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
```
@@ -176,8 +176,8 @@ def create_openai_fn_runnable(
# to see an up to date list of which models support
# with_structured_output.
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
structured_llm = model.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats.
structured_model = model.with_structured_output(Joke)
structured_model.invoke("Tell me a joke about cats.
Make sure to call the Joke function.")
"""
),
@@ -250,21 +250,21 @@ def create_structured_output_runnable(
color: str = Field(..., description="The dog's color")
fav_food: str | None = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are an extraction algorithm. Please extract every possible instance"),
('human', '{input}')
]
)
structured_llm = create_structured_output_runnable(
structured_model = create_structured_output_runnable(
RecordDog,
llm,
model,
mode="openai-tools",
enforce_function_usage=True,
return_single=True
)
structured_llm.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
structured_model.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
```
@@ -303,15 +303,15 @@ def create_structured_output_runnable(
}
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = create_structured_output_runnable(
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_model = create_structured_output_runnable(
dog_schema,
llm,
model,
mode="openai-tools",
enforce_function_usage=True,
return_single=True
)
structured_llm.invoke("Harry was a chubby brown beagle who loved chicken")
structured_model.invoke("Harry was a chubby brown beagle who loved chicken")
# -> {'name': 'Harry', 'color': 'brown', 'fav_food': 'chicken'}
```
@@ -330,9 +330,9 @@ def create_structured_output_runnable(
color: str = Field(..., description="The dog's color")
fav_food: str | None = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-functions")
structured_llm.invoke("Harry was a chubby brown beagle who loved chicken")
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_model = create_structured_output_runnable(Dog, model, mode="openai-functions")
structured_model.invoke("Harry was a chubby brown beagle who loved chicken")
# -> Dog(name="Harry", color="brown", fav_food="chicken")
```
@@ -352,13 +352,13 @@ def create_structured_output_runnable(
color: str = Field(..., description="The dog's color")
fav_food: str | None = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-functions")
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_model = create_structured_output_runnable(Dog, model, mode="openai-functions")
system = '''Extract information about any dogs mentioned in the user input.'''
prompt = ChatPromptTemplate.from_messages(
[("system", system), ("human", "{input}"),]
)
chain = prompt | structured_llm
chain = prompt | structured_model
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
# -> Dog(name="Harry", color="brown", fav_food="chicken")
```
@@ -379,8 +379,8 @@ def create_structured_output_runnable(
color: str = Field(..., description="The dog's color")
fav_food: str | None = Field(None, description="The dog's favorite food")
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-json")
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_model = create_structured_output_runnable(Dog, model, mode="openai-json")
system = '''You are a world class assistant for extracting information in structured JSON formats. \
Extract a valid JSON blob from the user input that matches the following JSON Schema:
@@ -389,7 +389,7 @@ def create_structured_output_runnable(
prompt = ChatPromptTemplate.from_messages(
[("system", system), ("human", "{input}"),]
)
chain = prompt | structured_llm
chain = prompt | structured_model
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
```

View File

@@ -113,10 +113,10 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
\"\"\"Very helpful answers to geography questions.\"\"\"
return f"{country}? IDK - We may never know {question}."
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent = initialize_agent(
tools=[geography_answers],
llm=llm,
llm=model,
agent=AgentType.OPENAI_FUNCTIONS,
return_intermediate_steps=True,
)
@@ -125,7 +125,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
response = agent(question)
eval_chain = TrajectoryEvalChain.from_llm(
llm=llm, agent_tools=[geography_answers], return_reasoning=True
llm=model, agent_tools=[geography_answers], return_reasoning=True
)
result = eval_chain.evaluate_agent_trajectory(
@@ -165,7 +165,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
"""Get the description of the agent tools.
Returns:
str: The description of the agent tools.
The description of the agent tools.
"""
if self.agent_tools is None:
return ""
@@ -184,10 +184,10 @@ Description: {tool.description}"""
"""Get the agent trajectory as a formatted string.
Args:
steps (Union[str, List[Tuple[AgentAction, str]]]): The agent trajectory.
steps: The agent trajectory.
Returns:
str: The formatted agent trajectory.
The formatted agent trajectory.
"""
if isinstance(steps, str):
return steps
@@ -240,7 +240,7 @@ The following is the expected answer. Use this to measure correctness:
**kwargs: Additional keyword arguments.
Returns:
TrajectoryEvalChain: The TrajectoryEvalChain object.
The `TrajectoryEvalChain` object.
"""
if not isinstance(llm, BaseChatModel):
msg = "Only chat models supported by the current trajectory eval"
@@ -259,7 +259,7 @@ The following is the expected answer. Use this to measure correctness:
"""Get the input keys for the chain.
Returns:
List[str]: The input keys.
The input keys.
"""
return ["question", "agent_trajectory", "answer", "reference"]
@@ -268,7 +268,7 @@ The following is the expected answer. Use this to measure correctness:
"""Get the output keys for the chain.
Returns:
List[str]: The output keys.
The output keys.
"""
return ["score", "reasoning"]
@@ -289,7 +289,7 @@ The following is the expected answer. Use this to measure correctness:
run_manager: The callback manager for the chain run.
Returns:
Dict[str, Any]: The output values of the chain.
The output values of the chain.
"""
chain_input = {**inputs}
if self.agent_tools:
@@ -313,7 +313,7 @@ The following is the expected answer. Use this to measure correctness:
run_manager: The callback manager for the chain run.
Returns:
Dict[str, Any]: The output values of the chain.
The output values of the chain.
"""
chain_input = {**inputs}
if self.agent_tools:

View File

@@ -165,10 +165,10 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
Example:
>>> from langchain_community.chat_models import ChatOpenAI
>>> from langchain_classic.evaluation.comparison import PairwiseStringEvalChain
>>> llm = ChatOpenAI(
>>> model = ChatOpenAI(
... temperature=0, model_name="gpt-4", model_kwargs={"random_seed": 42}
... )
>>> chain = PairwiseStringEvalChain.from_llm(llm=llm)
>>> chain = PairwiseStringEvalChain.from_llm(llm=model)
>>> result = chain.evaluate_string_pairs(
... input = "What is the chemical formula for water?",
... prediction = "H2O",
@@ -207,7 +207,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
"""Return whether the chain requires a reference.
Returns:
True if the chain requires a reference, False otherwise.
`True` if the chain requires a reference, `False` otherwise.
"""
return False
@@ -217,7 +217,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
"""Return whether the chain requires an input.
Returns:
bool: True if the chain requires an input, False otherwise.
`True` if the chain requires an input, `False` otherwise.
"""
return True
@@ -227,7 +227,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
"""Return the warning to show when reference is ignored.
Returns:
str: The warning to show when reference is ignored.
The warning to show when reference is ignored.
"""
return (
@@ -425,7 +425,7 @@ class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
"""Return whether the chain requires a reference.
Returns:
bool: True if the chain requires a reference, False otherwise.
`True` if the chain requires a reference, `False` otherwise.
"""
return True
@@ -442,18 +442,18 @@ class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
"""Initialize the LabeledPairwiseStringEvalChain from an LLM.
Args:
llm (BaseLanguageModel): The LLM to use.
prompt (PromptTemplate, optional): The prompt to use.
criteria (Union[CRITERIA_TYPE, str], optional): The criteria to use.
**kwargs (Any): Additional keyword arguments.
llm: The LLM to use.
prompt: The prompt to use.
criteria: The criteria to use.
**kwargs: Additional keyword arguments.
Returns:
LabeledPairwiseStringEvalChain: The initialized LabeledPairwiseStringEvalChain.
The initialized `LabeledPairwiseStringEvalChain`.
Raises:
ValueError: If the input variables are not as expected.
""" # noqa: E501
"""
expected_input_vars = {
"prediction",
"prediction_b",

View File

@@ -15,9 +15,9 @@ Using a predefined criterion:
>>> from langchain_community.llms import OpenAI
>>> from langchain_classic.evaluation.criteria import CriteriaEvalChain
>>> llm = OpenAI()
>>> model = OpenAI()
>>> criteria = "conciseness"
>>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
>>> chain = CriteriaEvalChain.from_llm(llm=model, criteria=criteria)
>>> chain.evaluate_strings(
prediction="The answer is 42.",
reference="42",
@@ -29,7 +29,7 @@ Using a custom criterion:
>>> from langchain_community.llms import OpenAI
>>> from langchain_classic.evaluation.criteria import LabeledCriteriaEvalChain
>>> llm = OpenAI()
>>> model = OpenAI()
>>> criteria = {
"hallucination": (
"Does this submission contain information"
@@ -37,7 +37,7 @@ Using a custom criterion:
),
}
>>> chain = LabeledCriteriaEvalChain.from_llm(
llm=llm,
llm=model,
criteria=criteria,
)
>>> chain.evaluate_strings(

View File

@@ -190,9 +190,9 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
--------
>>> from langchain_anthropic import ChatAnthropic
>>> from langchain_classic.evaluation.criteria import CriteriaEvalChain
>>> llm = ChatAnthropic(temperature=0)
>>> model = ChatAnthropic(temperature=0)
>>> criteria = {"my-custom-criterion": "Is the submission the most amazing ever?"}
>>> evaluator = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
>>> evaluator = CriteriaEvalChain.from_llm(llm=model, criteria=criteria)
>>> evaluator.evaluate_strings(
... prediction="Imagine an ice cream flavor for the color aquamarine",
... input="Tell me an idea",
@@ -205,10 +205,10 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
>>> from langchain_openai import ChatOpenAI
>>> from langchain_classic.evaluation.criteria import LabeledCriteriaEvalChain
>>> llm = ChatOpenAI(model="gpt-4", temperature=0)
>>> model = ChatOpenAI(model="gpt-4", temperature=0)
>>> criteria = "correctness"
>>> evaluator = LabeledCriteriaEvalChain.from_llm(
... llm=llm,
... llm=model,
... criteria=criteria,
... )
>>> evaluator.evaluate_strings(
@@ -347,7 +347,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
--------
>>> from langchain_openai import OpenAI
>>> from langchain_classic.evaluation.criteria import LabeledCriteriaEvalChain
>>> llm = OpenAI()
>>> model = OpenAI()
>>> criteria = {
"hallucination": (
"Does this submission contain information"
@@ -355,7 +355,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
),
}
>>> chain = LabeledCriteriaEvalChain.from_llm(
llm=llm,
llm=model,
criteria=criteria,
)
"""
@@ -433,9 +433,9 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
Examples:
>>> from langchain_openai import OpenAI
>>> from langchain_classic.evaluation.criteria import CriteriaEvalChain
>>> llm = OpenAI()
>>> model = OpenAI()
>>> criteria = "conciseness"
>>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
>>> chain = CriteriaEvalChain.from_llm(llm=model, criteria=criteria)
>>> chain.evaluate_strings(
prediction="The answer is 42.",
reference="42",
@@ -485,9 +485,9 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
Examples:
>>> from langchain_openai import OpenAI
>>> from langchain_classic.evaluation.criteria import CriteriaEvalChain
>>> llm = OpenAI()
>>> model = OpenAI()
>>> criteria = "conciseness"
>>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
>>> chain = CriteriaEvalChain.from_llm(llm=model, criteria=criteria)
>>> await chain.aevaluate_strings(
prediction="The answer is 42.",
reference="42",
@@ -569,7 +569,7 @@ class LabeledCriteriaEvalChain(CriteriaEvalChain):
--------
>>> from langchain_openai import OpenAI
>>> from langchain_classic.evaluation.criteria import LabeledCriteriaEvalChain
>>> llm = OpenAI()
>>> model = OpenAI()
>>> criteria = {
"hallucination": (
"Does this submission contain information"
@@ -577,7 +577,7 @@ class LabeledCriteriaEvalChain(CriteriaEvalChain):
),
}
>>> chain = LabeledCriteriaEvalChain.from_llm(
llm=llm,
llm=model,
criteria=criteria,
)
"""

View File

@@ -51,7 +51,7 @@ def _embedding_factory() -> Embeddings:
"""Create an Embeddings object.
Returns:
Embeddings: The created Embeddings object.
The created `Embeddings` object.
"""
# Here for backwards compatibility.
# Generally, we do not want to be seeing imports from langchain community
@@ -94,9 +94,8 @@ class _EmbeddingDistanceChainMixin(Chain):
"""Shared functionality for embedding distance evaluators.
Attributes:
embeddings (Embeddings): The embedding objects to vectorize the outputs.
distance_metric (EmbeddingDistance): The distance metric to use
for comparing the embeddings.
embeddings: The embedding objects to vectorize the outputs.
distance_metric: The distance metric to use for comparing the embeddings.
"""
embeddings: Embeddings = Field(default_factory=_embedding_factory)
@@ -107,10 +106,10 @@ class _EmbeddingDistanceChainMixin(Chain):
"""Validate that the TikTok library is installed.
Args:
values (Dict[str, Any]): The values to validate.
values: The values to validate.
Returns:
Dict[str, Any]: The validated values.
The validated values.
"""
embeddings = values.get("embeddings")
types_ = []
@@ -159,7 +158,7 @@ class _EmbeddingDistanceChainMixin(Chain):
"""Return the output keys of the chain.
Returns:
List[str]: The output keys.
The output keys.
"""
return ["score"]
@@ -173,10 +172,10 @@ class _EmbeddingDistanceChainMixin(Chain):
"""Get the metric function for the given metric name.
Args:
metric (EmbeddingDistance): The metric name.
metric: The metric name.
Returns:
Any: The metric function.
The metric function.
"""
metrics = {
EmbeddingDistance.COSINE: self._cosine_distance,
@@ -334,7 +333,7 @@ class _EmbeddingDistanceChainMixin(Chain):
vectors (np.ndarray): The input vectors.
Returns:
float: The computed score.
The computed score.
"""
metric = self._get_metric(self.distance_metric)
if _check_numpy() and isinstance(vectors, _import_numpy().ndarray):
@@ -362,7 +361,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
"""Return whether the chain requires a reference.
Returns:
bool: True if a reference is required, False otherwise.
True if a reference is required, `False` otherwise.
"""
return True
@@ -376,7 +375,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
"""Return the input keys of the chain.
Returns:
List[str]: The input keys.
The input keys.
"""
return ["prediction", "reference"]
@@ -393,7 +392,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
run_manager: The callback manager.
Returns:
Dict[str, Any]: The computed score.
The computed score.
"""
vectors = self.embeddings.embed_documents(
[inputs["prediction"], inputs["reference"]],
@@ -413,12 +412,11 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
"""Asynchronously compute the score for a prediction and reference.
Args:
inputs (Dict[str, Any]): The input data.
run_manager (AsyncCallbackManagerForChainRun, optional):
The callback manager.
inputs: The input data.
run_manager: The callback manager.
Returns:
Dict[str, Any]: The computed score.
The computed score.
"""
vectors = await self.embeddings.aembed_documents(
[
@@ -523,7 +521,7 @@ class PairwiseEmbeddingDistanceEvalChain(
"""Return the input keys of the chain.
Returns:
List[str]: The input keys.
The input keys.
"""
return ["prediction", "prediction_b"]
@@ -541,12 +539,11 @@ class PairwiseEmbeddingDistanceEvalChain(
"""Compute the score for two predictions.
Args:
inputs (Dict[str, Any]): The input data.
run_manager (CallbackManagerForChainRun, optional):
The callback manager.
inputs: The input data.
run_manager: The callback manager.
Returns:
Dict[str, Any]: The computed score.
The computed score.
"""
vectors = self.embeddings.embed_documents(
[
@@ -569,12 +566,11 @@ class PairwiseEmbeddingDistanceEvalChain(
"""Asynchronously compute the score for two predictions.
Args:
inputs (Dict[str, Any]): The input data.
run_manager (AsyncCallbackManagerForChainRun, optional):
The callback manager.
inputs: The input data.
run_manager: The callback manager.
Returns:
Dict[str, Any]: The computed score.
The computed score.
"""
vectors = await self.embeddings.aembed_documents(
[

View File

@@ -61,7 +61,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
"""Get the input keys.
Returns:
List[str]: The input keys.
The input keys.
"""
return ["reference", "prediction"]
@@ -70,7 +70,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
"""Get the evaluation name.
Returns:
str: The evaluation name.
The evaluation name.
"""
return "exact_match"

View File

@@ -71,10 +71,11 @@ class JsonValidityEvaluator(StringEvaluator):
**kwargs: Additional keyword arguments (not used).
Returns:
dict: A dictionary containing the evaluation score. The score is 1 if
the prediction is valid JSON, and 0 otherwise.
If the prediction is not valid JSON, the dictionary also contains
a "reasoning" field with the error message.
A dictionary containing the evaluation score. The score is `1` if
the prediction is valid JSON, and `0` otherwise.
If the prediction is not valid JSON, the dictionary also contains
a `reasoning` field with the error message.
"""
try:

View File

@@ -50,7 +50,7 @@ class JsonEditDistanceEvaluator(StringEvaluator):
Raises:
ImportError: If the `rapidfuzz` package is not installed and no
`string_distance` function is provided.
`string_distance` function is provided.
"""
super().__init__()
if string_distance is not None:

View File

@@ -113,17 +113,16 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
"""Load QA Eval Chain from LLM.
Args:
llm (BaseLanguageModel): the base language model to use.
llm: The base language model to use.
prompt: A prompt template containing the input_variables:
`'input'`, `'answer'` and `'result'` that will be used as the prompt
for evaluation.
prompt (PromptTemplate): A prompt template containing the input_variables:
'input', 'answer' and 'result' that will be used as the prompt
for evaluation.
Defaults to PROMPT.
**kwargs: additional keyword arguments.
Defaults to `PROMPT`.
**kwargs: Additional keyword arguments.
Returns:
QAEvalChain: the loaded QA eval chain.
The loaded QA eval chain.
"""
prompt = prompt or PROMPT
expected_input_vars = {"query", "answer", "result"}
@@ -264,17 +263,16 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
"""Load QA Eval Chain from LLM.
Args:
llm (BaseLanguageModel): the base language model to use.
llm: The base language model to use.
prompt: A prompt template containing the `input_variables`:
`'query'`, `'context'` and `'result'` that will be used as the prompt
for evaluation.
prompt (PromptTemplate): A prompt template containing the input_variables:
'query', 'context' and 'result' that will be used as the prompt
for evaluation.
Defaults to PROMPT.
**kwargs: additional keyword arguments.
Defaults to `PROMPT`.
**kwargs: Additional keyword arguments.
Returns:
ContextQAEvalChain: the loaded QA eval chain.
The loaded QA eval chain.
"""
prompt = prompt or CONTEXT_PROMPT
cls._validate_input_vars(prompt)

View File

@@ -53,7 +53,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
"""Get the input keys.
Returns:
List[str]: The input keys.
The input keys.
"""
return ["reference", "prediction"]
@@ -62,7 +62,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
"""Get the evaluation name.
Returns:
str: The evaluation name.
The evaluation name.
"""
return "regex_match"

View File

@@ -114,8 +114,8 @@ class _EvalArgsMixin:
"""Check if the evaluation arguments are valid.
Args:
reference (str | None, optional): The reference label.
input_ (str | None, optional): The input string.
reference: The reference label.
input_: The input string.
Raises:
ValueError: If the evaluator requires an input string but none is provided,
@@ -162,17 +162,17 @@ class StringEvaluator(_EvalArgsMixin, ABC):
"""Evaluate Chain or LLM output, based on optional input and label.
Args:
prediction (str): The LLM or chain prediction to evaluate.
reference (str | None, optional): The reference label to evaluate against.
input (str | None, optional): The input to consider during evaluation.
kwargs: Additional keyword arguments, including callbacks, tags, etc.
prediction: The LLM or chain prediction to evaluate.
reference: The reference label to evaluate against.
input: The input to consider during evaluation.
**kwargs: Additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
It is recommended that the dictionary contain the following keys:
- score: the score of the evaluation, if applicable.
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
The evaluation results containing the score or value.
It is recommended that the dictionary contain the following keys:
- score: the score of the evaluation, if applicable.
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
"""
async def _aevaluate_strings(
@@ -186,17 +186,17 @@ class StringEvaluator(_EvalArgsMixin, ABC):
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.
Args:
prediction (str): The LLM or chain prediction to evaluate.
reference (str | None, optional): The reference label to evaluate against.
input (str | None, optional): The input to consider during evaluation.
kwargs: Additional keyword arguments, including callbacks, tags, etc.
prediction: The LLM or chain prediction to evaluate.
reference: The reference label to evaluate against.
input: The input to consider during evaluation.
**kwargs: Additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
It is recommended that the dictionary contain the following keys:
- score: the score of the evaluation, if applicable.
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
The evaluation results containing the score or value.
It is recommended that the dictionary contain the following keys:
- score: the score of the evaluation, if applicable.
- value: the string value of the evaluation, if applicable.
- reasoning: the reasoning for the evaluation, if applicable.
""" # noqa: E501
return await run_in_executor(
None,
@@ -218,13 +218,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
"""Evaluate Chain or LLM output, based on optional input and label.
Args:
prediction (str): The LLM or chain prediction to evaluate.
reference (str | None, optional): The reference label to evaluate against.
input (str | None, optional): The input to consider during evaluation.
kwargs: Additional keyword arguments, including callbacks, tags, etc.
prediction: The LLM or chain prediction to evaluate.
reference: The reference label to evaluate against.
input: The input to consider during evaluation.
**kwargs: Additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
The evaluation results containing the score or value.
"""
self._check_evaluation_args(reference=reference, input_=input)
return self._evaluate_strings(
@@ -245,13 +245,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.
Args:
prediction (str): The LLM or chain prediction to evaluate.
reference (str | None, optional): The reference label to evaluate against.
input (str | None, optional): The input to consider during evaluation.
kwargs: Additional keyword arguments, including callbacks, tags, etc.
prediction: The LLM or chain prediction to evaluate.
reference: The reference label to evaluate against.
input: The input to consider during evaluation.
**kwargs: Additional keyword arguments, including callbacks, tags, etc.
Returns:
dict: The evaluation results containing the score or value.
The evaluation results containing the score or value.
""" # noqa: E501
self._check_evaluation_args(reference=reference, input_=input)
return await self._aevaluate_strings(
@@ -278,14 +278,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
"""Evaluate the output string pairs.
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
reference (str | None, optional): The expected output / reference string.
input (str | None, optional): The input string.
kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
prediction: The output string from the first model.
prediction_b: The output string from the second model.
reference: The expected output / reference string.
input: The input string.
**kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
async def _aevaluate_string_pairs(
@@ -300,14 +300,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
"""Asynchronously evaluate the output string pairs.
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
reference (str | None, optional): The expected output / reference string.
input (str | None, optional): The input string.
kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
prediction: The output string from the first model.
prediction_b: The output string from the second model.
reference: The expected output / reference string.
input: The input string.
**kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
return await run_in_executor(
None,
@@ -331,14 +331,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
"""Evaluate the output string pairs.
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
reference (str | None, optional): The expected output / reference string.
input (str | None, optional): The input string.
kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
prediction: The output string from the first model.
prediction_b: The output string from the second model.
reference: The expected output / reference string.
input: The input string.
**kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
self._check_evaluation_args(reference=reference, input_=input)
return self._evaluate_string_pairs(
@@ -361,14 +361,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
"""Asynchronously evaluate the output string pairs.
Args:
prediction (str): The output string from the first model.
prediction_b (str): The output string from the second model.
reference (str | None, optional): The expected output / reference string.
input (str | None, optional): The input string.
kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
prediction: The output string from the first model.
prediction_b: The output string from the second model.
reference: The expected output / reference string.
input: The input string.
**kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
Returns:
dict: A dictionary containing the preference, scores, and/or other information.
A dictionary containing the preference, scores, and/or other information.
""" # noqa: E501
self._check_evaluation_args(reference=reference, input_=input)
return await self._aevaluate_string_pairs(

View File

@@ -7,8 +7,8 @@ criteria and or a reference answer.
Example:
>>> from langchain_community.chat_models import ChatOpenAI
>>> from langchain_classic.evaluation.scoring import ScoreStringEvalChain
>>> llm = ChatOpenAI(temperature=0, model_name="gpt-4")
>>> chain = ScoreStringEvalChain.from_llm(llm=llm)
>>> model = ChatOpenAI(temperature=0, model_name="gpt-4")
>>> chain = ScoreStringEvalChain.from_llm(llm=model)
>>> result = chain.evaluate_strings(
... input="What is the chemical formula for water?",
... prediction="H2O",

View File

@@ -56,10 +56,10 @@ def resolve_criteria(
"""Resolve the criteria for the pairwise evaluator.
Args:
criteria (Union[CRITERIA_TYPE, str], optional): The criteria to use.
criteria: The criteria to use.
Returns:
dict: The resolved criteria.
The resolved criteria.
"""
if criteria is None:
@@ -156,8 +156,8 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
Example:
>>> from langchain_community.chat_models import ChatOpenAI
>>> from langchain_classic.evaluation.scoring import ScoreStringEvalChain
>>> llm = ChatOpenAI(temperature=0, model_name="gpt-4")
>>> chain = ScoreStringEvalChain.from_llm(llm=llm)
>>> model = ChatOpenAI(temperature=0, model_name="gpt-4")
>>> chain = ScoreStringEvalChain.from_llm(llm=model)
>>> result = chain.evaluate_strings(
... input="What is the chemical formula for water?",
... prediction="H2O",
@@ -196,7 +196,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
"""Return whether the chain requires a reference.
Returns:
bool: True if the chain requires a reference, False otherwise.
`True` if the chain requires a reference, `False` otherwise.
"""
return False
@@ -206,7 +206,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
"""Return whether the chain requires an input.
Returns:
bool: True if the chain requires an input, False otherwise.
`True` if the chain requires an input, `False` otherwise.
"""
return True
@@ -227,7 +227,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
"""Return the warning to show when reference is ignored.
Returns:
str: The warning to show when reference is ignored.
The warning to show when reference is ignored.
"""
return (
@@ -424,7 +424,7 @@ class LabeledScoreStringEvalChain(ScoreStringEvalChain):
"""Return whether the chain requires a reference.
Returns:
bool: True if the chain requires a reference, False otherwise.
`True` if the chain requires a reference, `False` otherwise.
"""
return True
@@ -442,14 +442,14 @@ class LabeledScoreStringEvalChain(ScoreStringEvalChain):
"""Initialize the LabeledScoreStringEvalChain from an LLM.
Args:
llm (BaseLanguageModel): The LLM to use.
prompt (PromptTemplate, optional): The prompt to use.
criteria (Union[CRITERIA_TYPE, str], optional): The criteria to use.
normalize_by (float, optional): The value to normalize the score by.
**kwargs (Any): Additional keyword arguments.
llm: The LLM to use.
prompt: The prompt to use.
criteria: The criteria to use.
normalize_by: The value to normalize the score by.
**kwargs: Additional keyword arguments.
Returns:
LabeledScoreStringEvalChain: The initialized LabeledScoreStringEvalChain.
The initialized LabeledScoreStringEvalChain.
Raises:
ValueError: If the input variables are not as expected.

View File

@@ -22,10 +22,10 @@ def _load_rapidfuzz() -> Any:
"""Load the RapidFuzz library.
Raises:
ImportError: If the rapidfuzz library is not installed.
`ImportError`: If the rapidfuzz library is not installed.
Returns:
Any: The rapidfuzz.distance module.
The `rapidfuzz.distance` module.
"""
try:
import rapidfuzz
@@ -42,12 +42,12 @@ class StringDistance(str, Enum):
"""Distance metric to use.
Attributes:
DAMERAU_LEVENSHTEIN: The Damerau-Levenshtein distance.
LEVENSHTEIN: The Levenshtein distance.
JARO: The Jaro distance.
JARO_WINKLER: The Jaro-Winkler distance.
HAMMING: The Hamming distance.
INDEL: The Indel distance.
`DAMERAU_LEVENSHTEIN`: The Damerau-Levenshtein distance.
`LEVENSHTEIN`: The Levenshtein distance.
`JARO`: The Jaro distance.
`JARO_WINKLER`: The Jaro-Winkler distance.
`HAMMING`: The Hamming distance.
`INDEL`: The Indel distance.
"""
DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
@@ -63,7 +63,7 @@ class _RapidFuzzChainMixin(Chain):
distance: StringDistance = Field(default=StringDistance.JARO_WINKLER)
normalize_score: bool = Field(default=True)
"""Whether to normalize the score to a value between 0 and 1.
"""Whether to normalize the score to a value between `0` and `1`.
Applies only to the Levenshtein and Damerau-Levenshtein distances."""
@pre_init
@@ -71,10 +71,10 @@ class _RapidFuzzChainMixin(Chain):
"""Validate that the rapidfuzz library is installed.
Args:
values (Dict[str, Any]): The input values.
values: The input values.
Returns:
Dict[str, Any]: The validated values.
The validated values.
"""
_load_rapidfuzz()
return values
@@ -84,7 +84,7 @@ class _RapidFuzzChainMixin(Chain):
"""Get the output keys.
Returns:
List[str]: The output keys.
The output keys.
"""
return ["score"]
@@ -92,10 +92,10 @@ class _RapidFuzzChainMixin(Chain):
"""Prepare the output dictionary.
Args:
result (Dict[str, Any]): The evaluation results.
result: The evaluation results.
Returns:
Dict[str, Any]: The prepared output dictionary.
The prepared output dictionary.
"""
result = {"score": result["score"]}
if RUN_KEY in result:
@@ -111,7 +111,7 @@ class _RapidFuzzChainMixin(Chain):
normalize_score: Whether to normalize the score.
Returns:
Callable: The distance metric function.
The distance metric function.
Raises:
ValueError: If the distance metric is invalid.
@@ -142,7 +142,7 @@ class _RapidFuzzChainMixin(Chain):
"""Get the distance metric function.
Returns:
Callable: The distance metric function.
The distance metric function.
"""
return _RapidFuzzChainMixin._get_metric(
self.distance,
@@ -199,7 +199,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""Get the input keys.
Returns:
List[str]: The input keys.
The input keys.
"""
return ["reference", "prediction"]
@@ -208,7 +208,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""Get the evaluation name.
Returns:
str: The evaluation name.
The evaluation name.
"""
return f"{self.distance.value}_distance"
@@ -330,7 +330,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""Get the input keys.
Returns:
List[str]: The input keys.
The input keys.
"""
return ["prediction", "prediction_b"]
@@ -339,7 +339,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""Get the evaluation name.
Returns:
str: The evaluation name.
The evaluation name.
"""
return f"pairwise_{self.distance.value}_distance"
@@ -352,12 +352,11 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""Compute the string distance between two predictions.
Args:
inputs (Dict[str, Any]): The input values.
run_manager (CallbackManagerForChainRun , optional):
The callback manager.
inputs: The input values.
run_manager: The callback manager.
Returns:
Dict[str, Any]: The evaluation results containing the score.
The evaluation results containing the score.
"""
return {
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
@@ -372,12 +371,11 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""Asynchronously compute the string distance between two predictions.
Args:
inputs (Dict[str, Any]): The input values.
run_manager (AsyncCallbackManagerForChainRun , optional):
The callback manager.
inputs: The input values.
run_manager: The callback manager.
Returns:
Dict[str, Any]: The evaluation results containing the score.
The evaluation results containing the score.
"""
return {
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),

View File

@@ -55,7 +55,7 @@ class VectorStoreIndexWrapper(BaseModel):
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"llm = OpenAI(temperature=0)"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}
@@ -90,7 +90,7 @@ class VectorStoreIndexWrapper(BaseModel):
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"llm = OpenAI(temperature=0)"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}
@@ -125,7 +125,7 @@ class VectorStoreIndexWrapper(BaseModel):
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"llm = OpenAI(temperature=0)"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}
@@ -160,7 +160,7 @@ class VectorStoreIndexWrapper(BaseModel):
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"llm = OpenAI(temperature=0)"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}

View File

@@ -21,15 +21,15 @@ class ModelLaboratory:
Args:
chains: A sequence of chains to experiment with.
Each chain must have exactly one input and one output variable.
names (list[str] | None): Optional list of names corresponding to each
chain. If provided, its length must match the number of chains.
names: Optional list of names corresponding to each chain.
If provided, its length must match the number of chains.
Raises:
ValueError: If any chain is not an instance of `Chain`.
ValueError: If a chain does not have exactly one input variable.
ValueError: If a chain does not have exactly one output variable.
ValueError: If the length of `names` does not match the number of chains.
`ValueError`: If any chain is not an instance of `Chain`.
`ValueError`: If a chain does not have exactly one input variable.
`ValueError`: If a chain does not have exactly one output variable.
`ValueError`: If the length of `names` does not match the number of chains.
"""
for chain in chains:
if not isinstance(chain, Chain):
@@ -72,7 +72,7 @@ class ModelLaboratory:
If provided, the prompt must contain exactly one input variable.
Returns:
ModelLaboratory: An instance of `ModelLaboratory` initialized with LLMs.
An instance of `ModelLaboratory` initialized with LLMs.
"""
if prompt is None:
prompt = PromptTemplate(input_variables=["_input"], template="{_input}")

View File

@@ -96,7 +96,7 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]):
# ```
Args:
only_json (bool): If `True`, only the json in the Markdown code snippet
only_json: If `True`, only the json in the Markdown code snippet
will be returned, without the introducing text. Defaults to `False`.
"""
schema_str = "\n".join(

View File

@@ -16,10 +16,10 @@ T = TypeVar("T", bound=BaseModel)
class YamlOutputParser(BaseOutputParser[T]):
"""Parse YAML output using a pydantic model."""
"""Parse YAML output using a Pydantic model."""
pydantic_object: type[T]
"""The pydantic model to parse."""
"""The Pydantic model to parse."""
pattern: re.Pattern = re.compile(
r"^```(?:ya?ml)?(?P<yaml>[^`]*)",
re.MULTILINE | re.DOTALL,

View File

@@ -299,8 +299,8 @@ class EnsembleRetriever(BaseRetriever):
doc_lists: A list of rank lists, where each rank list contains unique items.
Returns:
list: The final aggregated list of items sorted by their weighted RRF
scores in descending order.
The final aggregated list of items sorted by their weighted RRF
scores in descending order.
"""
if len(doc_lists) != len(self.weights):
msg = "Number of rank lists must be equal to the number of weights."

View File

@@ -22,8 +22,8 @@ from langchain_classic.smith import RunEvalConfig, run_on_dataset
# Chains may have memory. Passing in a constructor function lets the
# evaluation framework avoid cross-contamination between runs.
def construct_chain():
llm = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(llm, "What's the answer to {your_input_key}")
model = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(model, "What's the answer to {your_input_key}")
return chain

View File

@@ -16,8 +16,8 @@ from langchain_classic.smith import EvaluatorType, RunEvalConfig, run_on_dataset
def construct_chain():
llm = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(llm, "What's the answer to {your_input_key}")
model = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(model, "What's the answer to {your_input_key}")
return chain

View File

@@ -982,8 +982,7 @@ def _run_llm_or_chain(
input_mapper: Optional function to map the input to the expected format.
Returns:
Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
The outputs of the model or chain.
The outputs of the model or chain.
"""
chain_or_llm = (
"LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
@@ -1396,9 +1395,9 @@ async def arun_on_dataset(
# Chains may have memory. Passing in a constructor function lets the
# evaluation framework avoid cross-contamination between runs.
def construct_chain():
llm = ChatOpenAI(temperature=0)
model = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(
llm,
model,
"What's the answer to {your_input_key}"
)
return chain
@@ -1571,9 +1570,9 @@ def run_on_dataset(
# Chains may have memory. Passing in a constructor function lets the
# evaluation framework avoid cross-contamination between runs.
def construct_chain():
llm = ChatOpenAI(temperature=0)
model = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(
llm,
model,
"What's the answer to {your_input_key}"
)
return chain

View File

@@ -47,14 +47,15 @@ class LocalFileStore(ByteStore):
"""Implement the BaseStore interface for the local file system.
Args:
root_path (Union[str, Path]): The root path of the file store. All keys are
interpreted as paths relative to this root.
chmod_file: (optional, defaults to `None`) If specified, sets permissions
for newly created files, overriding the current `umask` if needed.
chmod_dir: (optional, defaults to `None`) If specified, sets permissions
for newly created dirs, overriding the current `umask` if needed.
update_atime: (optional, defaults to `False`) If `True`, updates the
filesystem access time (but not the modified time) when a file is read.
root_path: The root path of the file store. All keys are interpreted as
paths relative to this root.
chmod_file: If specified, sets permissions for newly created files,
overriding the current `umask` if needed.
chmod_dir: If specified, sets permissions for newly created dirs,
overriding the current `umask` if needed.
update_atime: If `True`, updates the filesystem access time
(but not the modified time) when a file is read.
This allows MRU/LRU cache policies to be implemented for filesystems
where access time updates are disabled.
"""
@@ -67,10 +68,10 @@ class LocalFileStore(ByteStore):
"""Get the full path for a given key relative to the root path.
Args:
key (str): The key relative to the root path.
key: The key relative to the root path.
Returns:
Path: The full path for the given key.
The full path for the given key.
"""
if not re.match(r"^[a-zA-Z0-9_.\-/]+$", key):
msg = f"Invalid characters in key: {key}"
@@ -148,10 +149,8 @@ class LocalFileStore(ByteStore):
"""Delete the given keys and their associated values.
Args:
keys (Sequence[str]): A sequence of keys to delete.
keys: A sequence of keys to delete.
Returns:
None
"""
for key in keys:
full_path = self._get_full_path(key)
@@ -162,10 +161,10 @@ class LocalFileStore(ByteStore):
"""Get an iterator over keys that match the given prefix.
Args:
prefix (str | None): The prefix to match.
prefix: The prefix to match.
Returns:
Iterator[str]: An iterator over keys that match the given prefix.
Yields:
Keys that match the given prefix.
"""
prefix_path = self._get_full_path(prefix) if prefix else self.root_path
for file in prefix_path.rglob("*"):

View File

@@ -27,10 +27,10 @@ def __getattr__(name: str) -> Any:
"""Dynamically retrieve attributes from the updated module path.
Args:
name (str): The name of the attribute to import.
name: The name of the attribute to import.
Returns:
Any: The resolved attribute from the updated path.
The resolved attribute from the updated path.
"""
return _import_attribute(name)

View File

@@ -35,10 +35,10 @@ def __getattr__(name: str) -> Any:
at runtime and forward them to their new locations.
Args:
name (str): The name of the attribute to import.
name: The name of the attribute to import.
Returns:
Any: The resolved attribute from the appropriate updated module.
The resolved attribute from the appropriate updated module.
"""
return _import_attribute(name)

View File

@@ -33,10 +33,10 @@ def __getattr__(name: str) -> Any:
at runtime and forward them to their new locations.
Args:
name (str): The name of the attribute to import.
name: The name of the attribute to import.
Returns:
Any: The resolved attribute from the appropriate updated module.
The resolved attribute from the appropriate updated module.
"""
return _import_attribute(name)

View File

@@ -108,15 +108,15 @@ def test_no_more_changes_to_proxy_community() -> None:
def extract_deprecated_lookup(file_path: str) -> dict[str, Any] | None:
"""Detect and extracts the value of a dictionary named DEPRECATED_LOOKUP.
"""Detect and extracts the value of a dictionary named `DEPRECATED_LOOKUP`.
This variable is located in the global namespace of a Python file.
Args:
file_path (str): The path to the Python file.
file_path: The path to the Python file.
Returns:
dict or None: The value of DEPRECATED_LOOKUP if it exists, None otherwise.
The value of `DEPRECATED_LOOKUP` if it exists, `None` otherwise.
"""
tree = ast.parse(Path(file_path).read_text(encoding="utf-8"), filename=file_path)
@@ -136,10 +136,10 @@ def _dict_from_ast(node: ast.Dict) -> dict[str, str]:
"""Convert an AST dict node to a Python dictionary, assuming str to str format.
Args:
node (ast.Dict): The AST node representing a dictionary.
node: The AST node representing a dictionary.
Returns:
dict: The corresponding Python dictionary.
The corresponding Python dictionary.
"""
result: dict[str, str] = {}
for key, value in zip(node.keys, node.values, strict=False):
@@ -153,10 +153,10 @@ def _literal_eval_str(node: ast.AST) -> str:
"""Evaluate an AST literal node to its corresponding string value.
Args:
node (ast.AST): The AST node representing a literal value.
node: The AST node representing a literal value.
Returns:
str: The corresponding string value.
The corresponding string value.
"""
if isinstance(node, ast.Constant) and isinstance(node.value, str):
return node.value

View File

@@ -13,9 +13,6 @@ from typing import (
get_type_hints,
)
if TYPE_CHECKING:
from collections.abc import Awaitable
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
@@ -47,11 +44,10 @@ from langchain.agents.structured_output import (
ToolStrategy,
)
from langchain.chat_models import init_chat_model
from langchain.tools import ToolNode
from langchain.tools.tool_node import ToolCallWithContext
from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Awaitable, Callable, Sequence
from langchain_core.runnables import Runnable
from langgraph.cache.base import BaseCache
@@ -449,6 +445,70 @@ def _chain_tool_call_wrappers(
return result
def _chain_async_tool_call_wrappers(
wrappers: Sequence[
Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
]
],
) -> (
Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
]
| None
):
"""Compose async wrappers into middleware stack (first = outermost).
Args:
wrappers: Async wrappers in middleware order.
Returns:
Composed async wrapper, or None if empty.
"""
if not wrappers:
return None
if len(wrappers) == 1:
return wrappers[0]
def compose_two(
outer: Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
],
inner: Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
],
) -> Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
]:
"""Compose two async wrappers where outer wraps inner."""
async def composed(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
# Create an async callable that invokes inner with the original execute
async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
return await inner(req, execute)
# Outer can call call_inner multiple times
return await outer(request, call_inner)
return composed
# Chain all wrappers: first -> second -> ... -> last
result = wrappers[-1]
for wrapper in reversed(wrappers[:-1]):
result = compose_two(wrapper, result)
return result
def create_agent( # noqa: PLR0915
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
@@ -576,9 +636,14 @@ def create_agent( # noqa: PLR0915
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
# Collect middleware with wrap_tool_call hooks
# Collect middleware with wrap_tool_call or awrap_tool_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_wrap_tool_call = [
m for m in middleware if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
m
for m in middleware
if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
]
# Chain all wrap_tool_call handlers into a single composed handler
@@ -587,8 +652,24 @@ def create_agent( # noqa: PLR0915
wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
# Collect middleware with awrap_tool_call or wrap_tool_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_awrap_tool_call = [
m
for m in middleware
if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
]
# Chain all awrap_tool_call handlers into a single composed async handler
awrap_tool_call_wrapper = None
if middleware_w_awrap_tool_call:
async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
# Setup tools
tool_node: ToolNode | None = None
tool_node: _ToolNode | None = None
# Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
built_in_tools = [t for t in tools if isinstance(t, dict)]
regular_tools = [t for t in tools if not isinstance(t, dict)]
@@ -598,7 +679,11 @@ def create_agent( # noqa: PLR0915
# Only create ToolNode if we have client-side tools
tool_node = (
ToolNode(tools=available_tools, wrap_tool_call=wrap_tool_call_wrapper)
_ToolNode(
tools=available_tools,
wrap_tool_call=wrap_tool_call_wrapper,
awrap_tool_call=awrap_tool_call_wrapper,
)
if available_tools
else None
)
@@ -640,13 +725,23 @@ def create_agent( # noqa: PLR0915
if m.__class__.after_agent is not AgentMiddleware.after_agent
or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
]
# Collect middleware with wrap_model_call or awrap_model_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_wrap_model_call = [
m for m in middleware if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
m
for m in middleware
if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
]
# Collect middleware with awrap_model_call or wrap_model_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_awrap_model_call = [
m
for m in middleware
if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
]
# Compose wrap_model_call handlers into a single middleware stack (sync)
@@ -1378,7 +1473,7 @@ def _make_model_to_model_edge(
def _make_tools_to_model_edge(
*,
tool_node: ToolNode,
tool_node: _ToolNode,
model_destination: str,
structured_output_tools: dict[str, OutputToolBinding],
end_destination: str,

View File

@@ -20,7 +20,11 @@ from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
AgentState,
ModelCallHandler,
ModelCallResult,
ModelCallWrapper,
ModelRequest,
ModelResponse,
after_agent,
after_model,
before_agent,
@@ -28,6 +32,7 @@ from .types import (
dynamic_prompt,
hook_config,
wrap_model_call,
wrap_tool_call,
)
__all__ = [
@@ -41,9 +46,13 @@ __all__ = [
"InterruptOnConfig",
"LLMToolEmulator",
"LLMToolSelectorMiddleware",
"ModelCallHandler",
"ModelCallLimitMiddleware",
"ModelCallResult",
"ModelCallWrapper",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",
"PIIDetectionError",
"PIIMiddleware",
"PlanningMiddleware",
@@ -56,4 +65,5 @@ __all__ = [
"dynamic_prompt",
"hook_config",
"wrap_model_call",
"wrap_tool_call",
]

Some files were not shown because too many files have changed in this diff Show More