Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-06 17:20:16 +00:00)

Compare commits: 20 commits, langchain= ... sr/fixing-
| Author | SHA1 | Date |
|---|---|---|
| | 89d10ca1a9 | |
| | 760fc3bc12 | |
| | e3fc7d8aa6 | |
| | 2b3b209e40 | |
| | 78903ac285 | |
| | f361acc11c | |
| | ed185c0026 | |
| | 6dc34beb71 | |
| | c2205f88e6 | |
| | abdbe185c5 | |
| | c1b816cb7e | |
| | 0559558715 | |
| | 75965474fc | |
| | 5dc014fdf4 | |
| | 291a9fcea1 | |
| | dd994b9d7f | |
| | 83901b30e3 | |
| | bcfa21a6e7 | |
| | af1da28459 | |
| | ed2ee4e8cc | |
23 changed lines: .github/workflows/integration_tests.yml (vendored)

@@ -23,10 +23,8 @@ permissions:
contents: read

env:
POETRY_VERSION: "1.8.4"
UV_FROZEN: "true"
DEFAULT_LIBS: '["libs/partners/openai", "libs/partners/anthropic", "libs/partners/fireworks", "libs/partners/groq", "libs/partners/mistralai", "libs/partners/xai", "libs/partners/google-vertexai", "libs/partners/google-genai", "libs/partners/aws"]'
POETRY_LIBS: ("libs/partners/aws")

jobs:
# Generate dynamic test matrix based on input parameters or defaults

@@ -60,7 +58,6 @@ jobs:
echo $matrix
echo "matrix=$matrix" >> $GITHUB_OUTPUT
# Run integration tests against partner libraries with live API credentials
# Tests are run with Poetry or UV depending on the library's setup
build:
if: github.repository_owner == 'langchain-ai' || github.event_name != 'schedule'
name: "🐍 Python ${{ matrix.python-version }}: ${{ matrix.working-directory }}"

@@ -95,17 +92,7 @@ jobs:
mv langchain-google/libs/vertexai langchain/libs/partners/google-vertexai
mv langchain-aws/libs/aws langchain/libs/partners/aws

- name: "🐍 Set up Python ${{ matrix.python-version }} + Poetry"
  if: contains(env.POETRY_LIBS, matrix.working-directory)
  uses: "./langchain/.github/actions/poetry_setup"
  with:
    python-version: ${{ matrix.python-version }}
    poetry-version: ${{ env.POETRY_VERSION }}
    working-directory: langchain/${{ matrix.working-directory }}
    cache-key: scheduled

- name: "🐍 Set up Python ${{ matrix.python-version }} + UV"
  if: "!contains(env.POETRY_LIBS, matrix.working-directory)"
  uses: "./langchain/.github/actions/uv_setup"
  with:
    python-version: ${{ matrix.python-version }}

@@ -123,15 +110,7 @@ jobs:
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ secrets.AWS_REGION }}

- name: "📦 Install Dependencies (Poetry)"
  if: contains(env.POETRY_LIBS, matrix.working-directory)
  run: |
    echo "Running scheduled tests, installing dependencies with poetry..."
    cd langchain/${{ matrix.working-directory }}
    poetry install --with=test_integration,test

- name: "📦 Install Dependencies (UV)"
  if: "!contains(env.POETRY_LIBS, matrix.working-directory)"
- name: "📦 Install Dependencies"
  run: |
    echo "Running scheduled tests, installing dependencies with uv..."
    cd langchain/${{ matrix.working-directory }}
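Before this change, the workflow picked Poetry or UV per library. As a rough Python sketch (not part of the workflow itself), the selection logic the removed `contains(env.POETRY_LIBS, matrix.working-directory)` condition expressed, under the env values quoted above, looks like this:

```python
import json

# Values quoted from the workflow env block above.
DEFAULT_LIBS = json.loads(
    '["libs/partners/openai", "libs/partners/anthropic", "libs/partners/fireworks",'
    ' "libs/partners/groq", "libs/partners/mistralai", "libs/partners/xai",'
    ' "libs/partners/google-vertexai", "libs/partners/google-genai", "libs/partners/aws"]'
)
POETRY_LIBS = ("libs/partners/aws",)


def installer_for(working_directory: str) -> str:
    """Mirror the old `contains(env.POETRY_LIBS, matrix.working-directory)` check."""
    return "poetry" if working_directory in POETRY_LIBS else "uv"


for lib in DEFAULT_LIBS:
    print(f"{lib}: install with {installer_for(lib)}")
```

After the change every matrix entry takes the UV branch, which is why the Poetry setup and install steps are dropped.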
|
||||
|
||||
@@ -152,11 +152,11 @@ def send_email(to: str, msg: str, *, priority: str = "normal") -> bool:
|
||||
priority: Email priority level (`'low'`, `'normal'`, `'high'`).
|
||||
|
||||
Returns:
|
||||
True if email was sent successfully, False otherwise.
|
||||
`True` if email was sent successfully, `False` otherwise.
|
||||
|
||||
Raises:
|
||||
InvalidEmailError: If the email address format is invalid.
|
||||
SMTPConnectionError: If unable to connect to email server.
|
||||
`InvalidEmailError`: If the email address format is invalid.
|
||||
`SMTPConnectionError`: If unable to connect to email server.
|
||||
"""
|
||||
```
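For context, here is a minimal sketch of a function that would carry the docstring above. The body and the exception class are illustrative assumptions, not part of the change; the point is how the documented `Returns` and `Raises` entries map onto code.

```python
class InvalidEmailError(ValueError):
    """Illustrative exception matching the documented Raises section."""


def send_email(to: str, msg: str, *, priority: str = "normal") -> bool:
    """Send an email (docstring elided; see the hunk above for the documented style)."""
    if "@" not in to:
        # Corresponds to the documented `InvalidEmailError` case.
        raise InvalidEmailError(f"Invalid email address: {to!r}")
    print(f"[{priority}] to={to}: {msg}")
    return True  # Documented contract: `True` on success, `False` otherwise.
```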
|
||||
|
||||
@@ -19,8 +19,8 @@ And you should configure credentials by setting the following environment variab
|
||||
```python
|
||||
from __module_name__ import Chat__ModuleName__
|
||||
|
||||
llm = Chat__ModuleName__()
|
||||
llm.invoke("Sing a ballad of LangChain.")
|
||||
model = Chat__ModuleName__()
|
||||
model.invoke("Sing a ballad of LangChain.")
|
||||
```
|
||||
|
||||
## Embeddings
|
||||
@@ -41,6 +41,6 @@ embeddings.embed_query("What is the meaning of life?")
|
||||
```python
|
||||
from __module_name__ import __ModuleName__LLM
|
||||
|
||||
llm = __ModuleName__LLM()
|
||||
llm.invoke("The meaning of life is")
|
||||
model = __ModuleName__LLM()
|
||||
model.invoke("The meaning of life is")
|
||||
```
|
||||
|
||||
@@ -72,7 +72,9 @@
|
||||
"cell_type": "markdown",
|
||||
"id": "72ee0c4b-9764-423a-9dbf-95129e185210",
|
||||
"metadata": {},
|
||||
"source": "To enable automated tracing of your model calls, set your [LangSmith](https://docs.smith.langchain.com/) API key:"
|
||||
"source": [
|
||||
"To enable automated tracing of your model calls, set your [LangSmith](https://docs.smith.langchain.com/) API key:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -126,7 +128,7 @@
|
||||
"source": [
|
||||
"from __module_name__ import Chat__ModuleName__\n",
|
||||
"\n",
|
||||
"llm = Chat__ModuleName__(\n",
|
||||
"model = Chat__ModuleName__(\n",
|
||||
" model=\"model-name\",\n",
|
||||
" temperature=0,\n",
|
||||
" max_tokens=None,\n",
|
||||
@@ -162,7 +164,7 @@
|
||||
" ),\n",
|
||||
" (\"human\", \"I love programming.\"),\n",
|
||||
"]\n",
|
||||
"ai_msg = llm.invoke(messages)\n",
|
||||
"ai_msg = model.invoke(messages)\n",
|
||||
"ai_msg"
|
||||
]
|
||||
},
|
||||
@@ -207,7 +209,7 @@
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = prompt | llm\n",
|
||||
"chain = prompt | model\n",
|
||||
"chain.invoke(\n",
|
||||
" {\n",
|
||||
" \"input_language\": \"English\",\n",
|
||||
|
||||
@@ -65,7 +65,9 @@
|
||||
"cell_type": "markdown",
|
||||
"id": "4b6e1ca6",
|
||||
"metadata": {},
|
||||
"source": "To enable automated tracing of your model calls, set your [LangSmith](https://docs.smith.langchain.com/) API key:"
|
||||
"source": [
|
||||
"To enable automated tracing of your model calls, set your [LangSmith](https://docs.smith.langchain.com/) API key:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -119,7 +121,7 @@
|
||||
"source": [
|
||||
"from __module_name__ import __ModuleName__LLM\n",
|
||||
"\n",
|
||||
"llm = __ModuleName__LLM(\n",
|
||||
"model = __ModuleName__LLM(\n",
|
||||
" model=\"model-name\",\n",
|
||||
" temperature=0,\n",
|
||||
" max_tokens=None,\n",
|
||||
@@ -141,7 +143,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"id": "035dea0f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -150,7 +152,7 @@
|
||||
"source": [
|
||||
"input_text = \"__ModuleName__ is an AI company that \"\n",
|
||||
"\n",
|
||||
"completion = llm.invoke(input_text)\n",
|
||||
"completion = model.invoke(input_text)\n",
|
||||
"completion"
|
||||
]
|
||||
},
|
||||
@@ -177,7 +179,7 @@
|
||||
"\n",
|
||||
"prompt = PromptTemplate(\"How to say {input} in {output_language}:\\n\")\n",
|
||||
"\n",
|
||||
"chain = prompt | llm\n",
|
||||
"chain = prompt | model\n",
|
||||
"chain.invoke(\n",
|
||||
" {\n",
|
||||
" \"output_language\": \"German\",\n",
|
||||
|
||||
@@ -155,7 +155,7 @@
|
||||
"\n",
|
||||
"from langchain_openai import ChatOpenAI\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)"
|
||||
"model = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -185,7 +185,7 @@
|
||||
"chain = (\n",
|
||||
" {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
|
||||
" | prompt\n",
|
||||
" | llm\n",
|
||||
" | model\n",
|
||||
" | StrOutputParser()\n",
|
||||
")"
|
||||
]
|
||||
|
||||
@@ -192,7 +192,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": null,
|
||||
"id": "af3123ad-7a02-40e5-b58e-7d56e23e5830",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -203,7 +203,7 @@
|
||||
"# !pip install -qU langchain langchain-openai\n",
|
||||
"from langchain.chat_models import init_chat_model\n",
|
||||
"\n",
|
||||
"llm = init_chat_model(model=\"gpt-4o\", model_provider=\"openai\")"
|
||||
"model = init_chat_model(model=\"gpt-4o\", model_provider=\"openai\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -216,7 +216,7 @@
|
||||
"from langgraph.prebuilt import create_react_agent\n",
|
||||
"\n",
|
||||
"tools = [tool]\n",
|
||||
"agent = create_react_agent(llm, tools)"
|
||||
"agent = create_react_agent(model, tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -60,7 +60,7 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
```python
|
||||
from __module_name__ import Chat__ModuleName__
|
||||
|
||||
llm = Chat__ModuleName__(
|
||||
model = Chat__ModuleName__(
|
||||
model="...",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
@@ -77,7 +77,7 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
("system", "You are a helpful translator. Translate the user sentence to French."),
|
||||
("human", "I love programming."),
|
||||
]
|
||||
llm.invoke(messages)
|
||||
model.invoke(messages)
|
||||
```
|
||||
|
||||
```python
|
||||
@@ -87,7 +87,7 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
# TODO: Delete if token-level streaming isn't supported.
|
||||
Stream:
|
||||
```python
|
||||
for chunk in llm.stream(messages):
|
||||
for chunk in model.stream(messages):
|
||||
print(chunk.text, end="")
|
||||
```
|
||||
|
||||
@@ -96,7 +96,7 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
```
|
||||
|
||||
```python
|
||||
stream = llm.stream(messages)
|
||||
stream = model.stream(messages)
|
||||
full = next(stream)
|
||||
for chunk in stream:
|
||||
full += chunk
|
||||
@@ -110,13 +110,13 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
# TODO: Delete if native async isn't supported.
|
||||
Async:
|
||||
```python
|
||||
await llm.ainvoke(messages)
|
||||
await model.ainvoke(messages)
|
||||
|
||||
# stream:
|
||||
# async for chunk in (await llm.astream(messages))
|
||||
# async for chunk in (await model.astream(messages))
|
||||
|
||||
# batch:
|
||||
# await llm.abatch([messages])
|
||||
# await model.abatch([messages])
|
||||
```
|
||||
|
||||
```python
|
||||
@@ -137,8 +137,8 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
|
||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
||||
|
||||
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
||||
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
|
||||
model_with_tools = model.bind_tools([GetWeather, GetPopulation])
|
||||
ai_msg = model_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
|
||||
ai_msg.tool_calls
|
||||
```
|
||||
|
||||
@@ -162,8 +162,8 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
punchline: str = Field(description="The punchline to the joke")
|
||||
rating: int | None = Field(description="How funny the joke is, from 1 to 10")
|
||||
|
||||
structured_llm = llm.with_structured_output(Joke)
|
||||
structured_llm.invoke("Tell me a joke about cats")
|
||||
structured_model = model.with_structured_output(Joke)
|
||||
structured_model.invoke("Tell me a joke about cats")
|
||||
```
|
||||
|
||||
```python
|
||||
@@ -176,8 +176,8 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
JSON mode:
|
||||
```python
|
||||
# TODO: Replace with appropriate bind arg.
|
||||
json_llm = llm.bind(response_format={"type": "json_object"})
|
||||
ai_msg = json_llm.invoke("Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]")
|
||||
json_model = model.bind(response_format={"type": "json_object"})
|
||||
ai_msg = json_model.invoke("Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]")
|
||||
ai_msg.content
|
||||
```
|
||||
|
||||
@@ -204,7 +204,7 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
},
|
||||
],
|
||||
)
|
||||
ai_msg = llm.invoke([message])
|
||||
ai_msg = model.invoke([message])
|
||||
ai_msg.content
|
||||
```
|
||||
|
||||
@@ -235,7 +235,7 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
# TODO: Delete if token usage metadata isn't supported.
|
||||
Token usage:
|
||||
```python
|
||||
ai_msg = llm.invoke(messages)
|
||||
ai_msg = model.invoke(messages)
|
||||
ai_msg.usage_metadata
|
||||
```
|
||||
|
||||
@@ -247,8 +247,8 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
Logprobs:
|
||||
```python
|
||||
# TODO: Replace with appropriate bind arg.
|
||||
logprobs_llm = llm.bind(logprobs=True)
|
||||
ai_msg = logprobs_llm.invoke(messages)
|
||||
logprobs_model = model.bind(logprobs=True)
|
||||
ai_msg = logprobs_model.invoke(messages)
|
||||
ai_msg.response_metadata["logprobs"]
|
||||
```
|
||||
|
||||
@@ -257,7 +257,7 @@ class Chat__ModuleName__(BaseChatModel):
|
||||
```
|
||||
Response metadata
|
||||
```python
|
||||
ai_msg = llm.invoke(messages)
|
||||
ai_msg = model.invoke(messages)
|
||||
ai_msg.response_metadata
|
||||
```
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class __ModuleName__Retriever(BaseRetriever):
|
||||
Question: {question}\"\"\"
|
||||
)
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0125")
|
||||
|
||||
def format_docs(docs):
|
||||
return "\\n\\n".join(doc.page_content for doc in docs)
|
||||
@@ -73,7 +73,7 @@ class __ModuleName__Retriever(BaseRetriever):
|
||||
chain = (
|
||||
{"context": retriever | format_docs, "question": RunnablePassthrough()}
|
||||
| prompt
|
||||
| llm
|
||||
| model
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ def is_subclass(class_obj: type, classes_: list[type]) -> bool:
|
||||
classes_: A list of classes to check against.
|
||||
|
||||
Returns:
|
||||
True if `class_obj` is a subclass of any class in `classes_`, False otherwise.
|
||||
`True` if `class_obj` is a subclass of any class in `classes_`, `False` otherwise.
|
||||
"""
|
||||
return any(
|
||||
issubclass(class_obj, kls)
|
||||
|
||||
@@ -35,7 +35,7 @@ def is_openai_data_block(
|
||||
different type, this function will return False.
|
||||
|
||||
Returns:
|
||||
True if the block is a valid OpenAI data block and matches the filter_
|
||||
`True` if the block is a valid OpenAI data block and matches the filter_
|
||||
(if provided).
|
||||
|
||||
"""
|
||||
|
||||
@@ -123,7 +123,6 @@ class BaseLanguageModel(
|
||||
* If instance of `BaseCache`, will use the provided cache.
|
||||
|
||||
Caching is not currently supported for streaming methods of models.
|
||||
|
||||
"""
|
||||
verbose: bool = Field(default_factory=_get_verbosity, exclude=True, repr=False)
|
||||
"""Whether to print out response text."""
|
||||
@@ -146,7 +145,7 @@ class BaseLanguageModel(
|
||||
def set_verbose(cls, verbose: bool | None) -> bool: # noqa: FBT001
|
||||
"""If verbose is `None`, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
This allows users to pass in `None` as verbose to access the global setting.
|
||||
|
||||
Args:
|
||||
verbose: The verbosity setting to use.
|
||||
@@ -186,12 +185,12 @@ class BaseLanguageModel(
|
||||
1. Take advantage of batched calls,
|
||||
2. Need more output from the model than just the top generated value,
|
||||
3. Are building chains that are agnostic to the underlying language model
|
||||
type (e.g., pure text completion models vs chat models).
|
||||
type (e.g., pure text completion models vs chat models).
|
||||
|
||||
Args:
|
||||
prompts: List of PromptValues. A PromptValue is an object that can be
|
||||
converted to match the format of any language model (string for pure
|
||||
text generation models and BaseMessages for chat models).
|
||||
prompts: List of `PromptValue` objects. A `PromptValue` is an object that
|
||||
can be converted to match the format of any language model (string for
|
||||
pure text generation models and `BaseMessage` objects for chat models).
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
callbacks: Callbacks to pass through. Used for executing additional
|
||||
@@ -200,8 +199,8 @@ class BaseLanguageModel(
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An LLMResult, which contains a list of candidate Generations for each input
|
||||
prompt and additional model provider-specific output.
|
||||
An `LLMResult`, which contains a list of candidate `Generation` objects for
|
||||
each input prompt and additional model provider-specific output.
|
||||
|
||||
"""
|
||||
|
||||
@@ -223,12 +222,12 @@ class BaseLanguageModel(
|
||||
1. Take advantage of batched calls,
|
||||
2. Need more output from the model than just the top generated value,
|
||||
3. Are building chains that are agnostic to the underlying language model
|
||||
type (e.g., pure text completion models vs chat models).
|
||||
type (e.g., pure text completion models vs chat models).
|
||||
|
||||
Args:
|
||||
prompts: List of PromptValues. A PromptValue is an object that can be
|
||||
converted to match the format of any language model (string for pure
|
||||
text generation models and BaseMessages for chat models).
|
||||
prompts: List of `PromptValue` objects. A `PromptValue` is an object that
|
||||
can be converted to match the format of any language model (string for
|
||||
pure text generation models and `BaseMessage` objects for chat models).
|
||||
stop: Stop words to use when generating. Model output is cut off at the
|
||||
first occurrence of any of these substrings.
|
||||
callbacks: Callbacks to pass through. Used for executing additional
|
||||
@@ -237,8 +236,8 @@ class BaseLanguageModel(
|
||||
to the model provider API call.
|
||||
|
||||
Returns:
|
||||
An `LLMResult`, which contains a list of candidate Generations for each
|
||||
input prompt and additional model provider-specific output.
|
||||
An `LLMResult`, which contains a list of candidate `Generation` objects for
|
||||
each input prompt and additional model provider-specific output.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -240,78 +240,54 @@ def _format_ls_structured_output(ls_structured_output_format: dict | None) -> di
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
"""Base class for chat models.
|
||||
r"""Base class for chat models.
|
||||
|
||||
Key imperative methods:
|
||||
Methods that actually call the underlying model.
|
||||
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| Method | Input | Output | Description |
|
||||
+===========================+================================================================+=====================================================================+==================================================================================================+
|
||||
| `invoke` | str | list[dict | tuple | BaseMessage] | PromptValue | BaseMessage | A single chat model call. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `ainvoke` | ''' | BaseMessage | Defaults to running invoke in an async executor. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `stream` | ''' | Iterator[BaseMessageChunk] | Defaults to yielding output of invoke. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `astream` | ''' | AsyncIterator[BaseMessageChunk] | Defaults to yielding output of ainvoke. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `astream_events` | ''' | AsyncIterator[StreamEvent] | Event types: 'on_chat_model_start', 'on_chat_model_stream', 'on_chat_model_end'. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `batch` | list['''] | list[BaseMessage] | Defaults to running invoke in concurrent threads. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `abatch` | list['''] | list[BaseMessage] | Defaults to running ainvoke in concurrent threads. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `batch_as_completed` | list['''] | Iterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running invoke in concurrent threads. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
| `abatch_as_completed` | list['''] | AsyncIterator[tuple[int, Union[BaseMessage, Exception]]] | Defaults to running ainvoke in concurrent threads. |
|
||||
+---------------------------+----------------------------------------------------------------+---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+
|
||||
This table provides a brief overview of the main imperative methods. Please see the base `Runnable` reference for full documentation.
|
||||
|
||||
This table provides a brief overview of the main imperative methods. Please see the base Runnable reference for full documentation.
|
||||
| Method | Input | Output | Description |
|
||||
| ---------------------- | ------------------------------------------------------------ | ---------------------------------------------------------- | -------------------------------------------------------------------------------- |
|
||||
| `invoke` | `str` \| `list[dict | tuple | BaseMessage]` \| `PromptValue` | `BaseMessage` | A single chat model call. |
|
||||
| `ainvoke` | `'''` | `BaseMessage` | Defaults to running `invoke` in an async executor. |
|
||||
| `stream` | `'''` | `Iterator[BaseMessageChunk]` | Defaults to yielding output of `invoke`. |
|
||||
| `astream` | `'''` | `AsyncIterator[BaseMessageChunk]` | Defaults to yielding output of `ainvoke`. |
|
||||
| `astream_events` | `'''` | `AsyncIterator[StreamEvent]` | Event types: `on_chat_model_start`, `on_chat_model_stream`, `on_chat_model_end`. |
|
||||
| `batch` | `list[''']` | `list[BaseMessage]` | Defaults to running `invoke` in concurrent threads. |
|
||||
| `abatch` | `list[''']` | `list[BaseMessage]` | Defaults to running `ainvoke` in concurrent threads. |
|
||||
| `batch_as_completed` | `list[''']` | `Iterator[tuple[int, Union[BaseMessage, Exception]]]` | Defaults to running `invoke` in concurrent threads. |
|
||||
| `abatch_as_completed` | `list[''']` | `AsyncIterator[tuple[int, Union[BaseMessage, Exception]]]` | Defaults to running `ainvoke` in concurrent threads. |
|
||||
|
||||
Key declarative methods:
|
||||
Methods for creating another Runnable using the ChatModel.
|
||||
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| Method | Description |
|
||||
+==================================+===========================================================================================================+
|
||||
| `bind_tools` | Create ChatModel that can call tools. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `with_structured_output` | Create wrapper that structures model output using schema. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `with_retry` | Create wrapper that retries model calls on failure. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `with_fallbacks` | Create wrapper that falls back to other models on failure. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `configurable_fields` | Specify init args of the model that can be configured at runtime via the RunnableConfig. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
| `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the RunnableConfig. |
|
||||
+----------------------------------+-----------------------------------------------------------------------------------------------------------+
|
||||
Methods for creating another `Runnable` using the chat model.
|
||||
|
||||
This table provides a brief overview of the main declarative methods. Please see the reference for each method for full documentation.
|
||||
|
||||
| Method | Description |
|
||||
| ---------------------------- | -------------------------------------------------------------------------------------------- |
|
||||
| `bind_tools` | Create chat model that can call tools. |
|
||||
| `with_structured_output` | Create wrapper that structures model output using schema. |
|
||||
| `with_retry` | Create wrapper that retries model calls on failure. |
|
||||
| `with_fallbacks` | Create wrapper that falls back to other models on failure. |
|
||||
| `configurable_fields` | Specify init args of the model that can be configured at runtime via the `RunnableConfig`. |
|
||||
| `configurable_alternatives` | Specify alternative models which can be swapped in at runtime via the `RunnableConfig`. |
|
||||
|
||||
Creating custom chat model:
|
||||
Custom chat model implementations should inherit from this class.
|
||||
Please reference the table below for information about which
|
||||
methods and properties are required or optional for implementations.
|
||||
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| Method/Property | Description | Required/Optional |
|
||||
+==================================+====================================================================+===================+
|
||||
| -------------------------------- | ------------------------------------------------------------------ | ----------------- |
|
||||
| `_generate` | Use to generate a chat result from a prompt | Required |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_llm_type` (property) | Used to uniquely identify the type of the model. Used for logging. | Required |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_identifying_params` (property) | Represent model parameterization for tracing purposes. | Optional |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_stream` | Use to implement streaming | Optional |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_agenerate` | Use to implement a native async method | Optional |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
| `_astream` | Use to implement async version of `_stream` | Optional |
|
||||
+----------------------------------+--------------------------------------------------------------------+-------------------+
|
||||
|
||||
Follow the guide for more information on how to implement a custom Chat Model:
|
||||
Follow the guide for more information on how to implement a custom chat model:
|
||||
[Guide](https://python.langchain.com/docs/how_to/custom_chat_model/).
|
||||
|
||||
""" # noqa: E501
|
||||
@@ -327,9 +303,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
|
||||
- If `True`, will always bypass streaming case.
|
||||
- If `'tool_calling'`, will bypass streaming case only when the model is called
|
||||
with a `tools` keyword argument. In other words, LangChain will automatically
|
||||
switch to non-streaming behavior (`invoke`) only when the tools argument is
|
||||
provided. This offers the best of both worlds.
|
||||
with a `tools` keyword argument. In other words, LangChain will automatically
|
||||
switch to non-streaming behavior (`invoke`) only when the tools argument is
|
||||
provided. This offers the best of both worlds.
|
||||
- If `False` (Default), will always use streaming case if available.
|
||||
|
||||
The main reason for this flag is that code might be written using `stream` and
|
||||
@@ -349,7 +325,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
Supported values:
|
||||
|
||||
- `'v0'`: provider-specific format in content (can lazily-parse with
|
||||
`.content_blocks`)
|
||||
`.content_blocks`)
|
||||
- `'v1'`: standardized format in content (consistent with `.content_blocks`)
|
||||
|
||||
Partner packages (e.g., `langchain-openai`) can also use this field to roll out
|
||||
@@ -1579,10 +1555,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
justification: str
|
||||
|
||||
|
||||
llm = ChatModel(model="model-name", temperature=0)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
model = ChatModel(model="model-name", temperature=0)
|
||||
structured_model = model.with_structured_output(AnswerWithJustification)
|
||||
|
||||
structured_llm.invoke(
|
||||
structured_model.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
)
|
||||
|
||||
@@ -1604,12 +1580,12 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
justification: str
|
||||
|
||||
|
||||
llm = ChatModel(model="model-name", temperature=0)
|
||||
structured_llm = llm.with_structured_output(
|
||||
model = ChatModel(model="model-name", temperature=0)
|
||||
structured_model = model.with_structured_output(
|
||||
AnswerWithJustification, include_raw=True
|
||||
)
|
||||
|
||||
structured_llm.invoke(
|
||||
structured_model.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
)
|
||||
# -> {
|
||||
@@ -1633,10 +1609,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
|
||||
|
||||
dict_schema = convert_to_openai_tool(AnswerWithJustification)
|
||||
llm = ChatModel(model="model-name", temperature=0)
|
||||
structured_llm = llm.with_structured_output(dict_schema)
|
||||
model = ChatModel(model="model-name", temperature=0)
|
||||
structured_model = model.with_structured_output(dict_schema)
|
||||
|
||||
structured_llm.invoke(
|
||||
structured_model.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
)
|
||||
# -> {
|
||||
|
||||
@@ -25,9 +25,9 @@ class BaseSerialized(TypedDict):
|
||||
id: list[str]
|
||||
"""The unique identifier of the object."""
|
||||
name: NotRequired[str]
|
||||
"""The name of the object. Optional."""
|
||||
"""The name of the object."""
|
||||
graph: NotRequired[dict[str, Any]]
|
||||
"""The graph of the object. Optional."""
|
||||
"""The graph of the object."""
|
||||
|
||||
|
||||
class SerializedConstructor(BaseSerialized):
|
||||
@@ -52,7 +52,7 @@ class SerializedNotImplemented(BaseSerialized):
|
||||
type: Literal["not_implemented"]
|
||||
"""The type of the object. Must be `'not_implemented'`."""
|
||||
repr: str | None
|
||||
"""The representation of the object. Optional."""
|
||||
"""The representation of the object."""
|
||||
|
||||
|
||||
def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
@@ -61,7 +61,7 @@ def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
|
||||
Args:
|
||||
value: The value.
|
||||
key: The key.
|
||||
model: The pydantic model.
|
||||
model: The Pydantic model.
|
||||
|
||||
Returns:
|
||||
Whether the value is different from the default.
|
||||
@@ -93,18 +93,18 @@ class Serializable(BaseModel, ABC):
|
||||
It relies on the following methods and properties:
|
||||
|
||||
- `is_lc_serializable`: Is this class serializable?
|
||||
By design, even if a class inherits from Serializable, it is not serializable by
|
||||
default. This is to prevent accidental serialization of objects that should not
|
||||
be serialized.
|
||||
By design, even if a class inherits from `Serializable`, it is not serializable
|
||||
by default. This is to prevent accidental serialization of objects that should
|
||||
not be serialized.
|
||||
- `get_lc_namespace`: Get the namespace of the langchain object.
|
||||
During deserialization, this namespace is used to identify
|
||||
the correct class to instantiate.
|
||||
Please see the `Reviver` class in `langchain_core.load.load` for more details.
|
||||
During deserialization, an additional mapping is used to handle
classes that have moved or been renamed across package versions.
During deserialization, an additional mapping is used to handle classes that
have moved or been renamed across package versions.
|
||||
- `lc_secrets`: A map of constructor argument names to secret ids.
|
||||
- `lc_attributes`: List of additional attribute names that should be included
|
||||
as part of the serialized representation.
|
||||
as part of the serialized representation.
|
||||
"""
|
||||
|
||||
# Remove default BaseModel init docstring.
|
||||
@@ -116,12 +116,12 @@ class Serializable(BaseModel, ABC):
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Is this class serializable?
|
||||
|
||||
By design, even if a class inherits from Serializable, it is not serializable by
|
||||
default. This is to prevent accidental serialization of objects that should not
|
||||
be serialized.
|
||||
By design, even if a class inherits from `Serializable`, it is not serializable
|
||||
by default. This is to prevent accidental serialization of objects that should
|
||||
not be serialized.
|
||||
|
||||
Returns:
|
||||
Whether the class is serializable. Default is False.
|
||||
Whether the class is serializable. Default is `False`.
|
||||
"""
|
||||
return False
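A small sketch of the opt-in behavior described above; the class and the exact output shape are illustrative, not taken from the diff.

```python
from langchain_core.load import dumpd
from langchain_core.load.serializable import Serializable


class MyConfig(Serializable):
    temperature: float = 0.0

    @classmethod
    def is_lc_serializable(cls) -> bool:
        # Subclasses must opt in; inheriting from Serializable alone is not enough.
        return True


print(dumpd(MyConfig(temperature=0.5)))
# -> a {"lc": 1, "type": "constructor", "id": [...], "kwargs": {...}} style dict
```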
|
||||
|
||||
@@ -133,7 +133,7 @@ class Serializable(BaseModel, ABC):
|
||||
namespace is ["langchain", "llms", "openai"]
|
||||
|
||||
Returns:
|
||||
The namespace as a list of strings.
|
||||
The namespace.
|
||||
"""
|
||||
return cls.__module__.split(".")
|
||||
|
||||
@@ -141,8 +141,7 @@ class Serializable(BaseModel, ABC):
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
"""A map of constructor argument names to secret ids.
|
||||
|
||||
For example,
|
||||
{"openai_api_key": "OPENAI_API_KEY"}
|
||||
For example, `{"openai_api_key": "OPENAI_API_KEY"}`
|
||||
"""
|
||||
return {}
|
||||
|
||||
@@ -151,6 +150,7 @@ class Serializable(BaseModel, ABC):
|
||||
"""List of attribute names that should be included in the serialized kwargs.
|
||||
|
||||
These attributes must be accepted by the constructor.
|
||||
|
||||
Default is an empty dictionary.
|
||||
"""
|
||||
return {}
|
||||
@@ -194,7 +194,7 @@ class Serializable(BaseModel, ABC):
|
||||
ValueError: If the class has deprecated attributes.
|
||||
|
||||
Returns:
|
||||
A json serializable object or a SerializedNotImplemented object.
|
||||
A json serializable object or a `SerializedNotImplemented` object.
|
||||
"""
|
||||
if not self.is_lc_serializable():
|
||||
return self.to_json_not_implemented()
|
||||
@@ -269,7 +269,7 @@ class Serializable(BaseModel, ABC):
|
||||
"""Serialize a "not implemented" object.
|
||||
|
||||
Returns:
|
||||
SerializedNotImplemented.
|
||||
`SerializedNotImplemented`.
|
||||
"""
|
||||
return to_json_not_implemented(self)
|
||||
|
||||
@@ -284,8 +284,8 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
|
||||
|
||||
Returns:
|
||||
Whether the field is useful. If the field is required, it is useful.
|
||||
If the field is not required, it is useful if the value is not None.
|
||||
If the field is not required and the value is None, it is useful if the
|
||||
If the field is not required, it is useful if the value is not `None`.
|
||||
If the field is not required and the value is `None`, it is useful if the
|
||||
default value is different from the value.
|
||||
"""
|
||||
field = type(inst).model_fields.get(key)
|
||||
@@ -344,10 +344,10 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
|
||||
"""Serialize a "not implemented" object.
|
||||
|
||||
Args:
|
||||
obj: object to serialize.
|
||||
obj: Object to serialize.
|
||||
|
||||
Returns:
|
||||
SerializedNotImplemented
|
||||
`SerializedNotImplemented`
|
||||
"""
|
||||
id_: list[str] = []
|
||||
try:
|
||||
|
||||
@@ -877,7 +877,7 @@ def is_data_content_block(block: dict) -> bool:
|
||||
block: The content block to check.
|
||||
|
||||
Returns:
|
||||
True if the content block is a data content block, False otherwise.
|
||||
`True` if the content block is a data content block, `False` otherwise.
|
||||
|
||||
"""
|
||||
if block.get("type") not in _get_data_content_block_types():
|
||||
|
||||
@@ -31,13 +31,13 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
|
||||
Args:
|
||||
result: The result of the LLM call.
|
||||
partial: Whether to parse partial JSON objects. Default is False.
|
||||
partial: Whether to parse partial JSON objects.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
`OutputParserException`: If the output is not valid JSON.
|
||||
"""
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
@@ -56,7 +56,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
|
||||
|
||||
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
"""Parse an output as the Json object."""
|
||||
"""Parse an output as the JSON object."""
|
||||
|
||||
strict: bool = False
|
||||
"""Whether to allow non-JSON-compliant strings.
|
||||
@@ -82,13 +82,13 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
|
||||
Args:
|
||||
result: The result of the LLM call.
|
||||
partial: Whether to parse partial JSON objects. Default is False.
|
||||
partial: Whether to parse partial JSON objects.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the output is not valid JSON.
|
||||
`OutputParserException`: If the output is not valid JSON.
|
||||
"""
|
||||
if len(result) != 1:
|
||||
msg = f"Expected exactly one result, but got {len(result)}"
|
||||
@@ -155,7 +155,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
|
||||
|
||||
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
"""Parse an output as the element of the Json object."""
|
||||
"""Parse an output as the element of the JSON object."""
|
||||
|
||||
key_name: str
|
||||
"""The name of the key to return."""
|
||||
@@ -165,7 +165,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
|
||||
Args:
|
||||
result: The result of the LLM call.
|
||||
partial: Whether to parse partial JSON objects. Default is False.
|
||||
partial: Whether to parse partial JSON objects.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object.
|
||||
@@ -177,16 +177,15 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
||||
|
||||
|
||||
class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
"""Parse an output as a pydantic object.
|
||||
"""Parse an output as a Pydantic object.
|
||||
|
||||
This parser is used to parse the output of a ChatModel that uses
|
||||
OpenAI function format to invoke functions.
|
||||
This parser is used to parse the output of a chat model that uses OpenAI function
|
||||
format to invoke functions.
|
||||
|
||||
The parser extracts the function call invocation and matches
|
||||
them to the pydantic schema provided.
|
||||
The parser extracts the function call invocation and matches them to the Pydantic
|
||||
schema provided.
|
||||
|
||||
An exception will be raised if the function call does not match
|
||||
the provided schema.
|
||||
An exception will be raised if the function call does not match the provided schema.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -221,7 +220,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
"""
|
||||
|
||||
pydantic_schema: type[BaseModel] | dict[str, type[BaseModel]]
|
||||
"""The pydantic schema to parse the output with.
|
||||
"""The Pydantic schema to parse the output with.
|
||||
|
||||
If multiple schemas are provided, then the function name will be used to
|
||||
determine which schema to use.
|
||||
@@ -230,7 +229,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_schema(cls, values: dict) -> Any:
|
||||
"""Validate the pydantic schema.
|
||||
"""Validate the Pydantic schema.
|
||||
|
||||
Args:
|
||||
values: The values to validate.
|
||||
@@ -239,7 +238,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
The validated values.
|
||||
|
||||
Raises:
|
||||
ValueError: If the schema is not a pydantic schema.
|
||||
`ValueError`: If the schema is not a Pydantic schema.
|
||||
"""
|
||||
schema = values["pydantic_schema"]
|
||||
if "args_only" not in values:
|
||||
@@ -262,10 +261,10 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
|
||||
Args:
|
||||
result: The result of the LLM call.
|
||||
partial: Whether to parse partial JSON objects. Default is False.
|
||||
partial: Whether to parse partial JSON objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If the pydantic schema is not valid.
|
||||
`ValueError`: If the Pydantic schema is not valid.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object.
|
||||
@@ -288,13 +287,13 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
||||
elif issubclass(pydantic_schema, BaseModelV1):
|
||||
pydantic_args = pydantic_schema.parse_raw(args)
|
||||
else:
|
||||
msg = f"Unsupported pydantic schema: {pydantic_schema}"
|
||||
msg = f"Unsupported Pydantic schema: {pydantic_schema}"
|
||||
raise ValueError(msg)
|
||||
return pydantic_args
|
||||
|
||||
|
||||
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
||||
"""Parse an output as an attribute of a pydantic object."""
|
||||
"""Parse an output as an attribute of a Pydantic object."""
|
||||
|
||||
attr_name: str
|
||||
"""The name of the attribute to return."""
|
||||
@@ -305,7 +304,7 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
||||
|
||||
Args:
|
||||
result: The result of the LLM call.
|
||||
partial: Whether to parse partial JSON objects. Default is False.
|
||||
partial: Whether to parse partial JSON objects.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object.
|
||||
|
||||
@@ -17,10 +17,10 @@ from langchain_core.utils.pydantic import (
|
||||
|
||||
|
||||
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
"""Parse an output using a pydantic model."""
|
||||
"""Parse an output using a Pydantic model."""
|
||||
|
||||
pydantic_object: Annotated[type[TBaseModel], SkipValidation()]
|
||||
"""The pydantic model to parse."""
|
||||
"""The Pydantic model to parse."""
|
||||
|
||||
def _parse_obj(self, obj: dict) -> TBaseModel:
|
||||
try:
|
||||
@@ -45,21 +45,20 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
def parse_result(
|
||||
self, result: list[Generation], *, partial: bool = False
|
||||
) -> TBaseModel | None:
|
||||
"""Parse the result of an LLM call to a pydantic object.
|
||||
"""Parse the result of an LLM call to a Pydantic object.
|
||||
|
||||
Args:
|
||||
result: The result of the LLM call.
|
||||
partial: Whether to parse partial JSON objects.
|
||||
If `True`, the output will be a JSON object containing
|
||||
all the keys that have been returned so far.
|
||||
Defaults to `False`.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the result is not valid JSON
|
||||
or does not conform to the pydantic model.
|
||||
`OutputParserException`: If the result is not valid JSON
|
||||
or does not conform to the Pydantic model.
|
||||
|
||||
Returns:
|
||||
The parsed pydantic object.
|
||||
The parsed Pydantic object.
|
||||
"""
|
||||
try:
|
||||
json_object = super().parse_result(result)
|
||||
@@ -70,13 +69,13 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
raise
|
||||
|
||||
def parse(self, text: str) -> TBaseModel:
|
||||
"""Parse the output of an LLM call to a pydantic object.
|
||||
"""Parse the output of an LLM call to a Pydantic object.
|
||||
|
||||
Args:
|
||||
text: The output of the LLM call.
|
||||
|
||||
Returns:
|
||||
The parsed pydantic object.
|
||||
The parsed Pydantic object.
|
||||
"""
|
||||
return super().parse(text)
|
||||
|
||||
@@ -107,7 +106,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
@property
|
||||
@override
|
||||
def OutputType(self) -> type[TBaseModel]:
|
||||
"""Return the pydantic model."""
|
||||
"""Return the Pydantic model."""
|
||||
return self.pydantic_object
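A brief usage sketch of the parser documented in these hunks; the `Joke` schema and the literal JSON string are illustrative.

```python
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field


class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline to the joke")


parser = PydanticOutputParser(pydantic_object=Joke)

# parse() takes raw LLM text and returns a validated Joke instance, raising
# OutputParserException on malformed or non-conforming JSON.
joke = parser.parse('{"setup": "Why do cats code?", "punchline": "For the mouse support."}')
print(joke.punchline)
```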
|
||||
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ class LLMResult(BaseModel):
|
||||
other: Another `LLMResult` object to compare against.
|
||||
|
||||
Returns:
|
||||
True if the generations and `llm_output` are equal, False otherwise.
|
||||
`True` if the generations and `llm_output` are equal, `False` otherwise.
|
||||
"""
|
||||
if not isinstance(other, LLMResult):
|
||||
return NotImplemented
|
||||
|
||||
@@ -44,7 +44,7 @@ class BaseRateLimiter(abc.ABC):
|
||||
the attempt. Defaults to `True`.
|
||||
|
||||
Returns:
|
||||
True if the tokens were successfully acquired, False otherwise.
|
||||
`True` if the tokens were successfully acquired, `False` otherwise.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -63,7 +63,7 @@ class BaseRateLimiter(abc.ABC):
|
||||
the attempt. Defaults to `True`.
|
||||
|
||||
Returns:
|
||||
True if the tokens were successfully acquired, False otherwise.
|
||||
`True` if the tokens were successfully acquired, `False` otherwise.
|
||||
"""
|
||||
|
||||
|
||||
@@ -210,7 +210,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
|
||||
the attempt. Defaults to `True`.
|
||||
|
||||
Returns:
|
||||
True if the tokens were successfully acquired, False otherwise.
|
||||
`True` if the tokens were successfully acquired, `False` otherwise.
|
||||
"""
|
||||
if not blocking:
|
||||
return self._consume()
|
||||
@@ -234,7 +234,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
|
||||
the attempt. Defaults to `True`.
|
||||
|
||||
Returns:
|
||||
True if the tokens were successfully acquired, False otherwise.
|
||||
`True` if the tokens were successfully acquired, `False` otherwise.
|
||||
"""
|
||||
if not blocking:
|
||||
return self._consume()
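A usage sketch of the `acquire` contract described above, using the in-memory implementation; the specific rate values are arbitrary.

```python
from langchain_core.rate_limiters import InMemoryRateLimiter

# Roughly 2 requests per second, polling every 100 ms, with a small burst bucket.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=2,
    check_every_n_seconds=0.1,
    max_bucket_size=4,
)

if rate_limiter.acquire(blocking=True):  # returns True once a token is obtained
    ...  # make the model call here; chat models also accept a rate_limiter= argument
```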
|
||||
|
||||
@@ -7,7 +7,6 @@ the backbone of a retriever, but there are other types of retrievers as well.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -15,8 +14,6 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import Self, TypedDict, override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import (
|
||||
@@ -138,35 +135,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
@override
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Version upgrade for old retrievers that implemented the public
|
||||
# methods directly.
|
||||
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
|
||||
warnings.warn(
|
||||
"Retrievers must implement abstract `_get_relevant_documents` method"
|
||||
" instead of `get_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
swap = cls.get_relevant_documents
|
||||
cls.get_relevant_documents = ( # type: ignore[method-assign]
|
||||
BaseRetriever.get_relevant_documents
|
||||
)
|
||||
cls._get_relevant_documents = swap # type: ignore[method-assign]
|
||||
if (
|
||||
hasattr(cls, "aget_relevant_documents")
|
||||
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
|
||||
):
|
||||
warnings.warn(
|
||||
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
||||
" instead of `aget_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
aswap = cls.aget_relevant_documents
|
||||
cls.aget_relevant_documents = ( # type: ignore[method-assign]
|
||||
BaseRetriever.aget_relevant_documents
|
||||
)
|
||||
cls._aget_relevant_documents = aswap # type: ignore[method-assign]
|
||||
parameters = signature(cls._get_relevant_documents).parameters
|
||||
cls._new_arg_supported = parameters.get("run_manager") is not None
|
||||
if (
|
||||
@@ -348,91 +316,3 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
query,
|
||||
run_manager=run_manager.get_sync(),
|
||||
)
|
||||
|
||||
@deprecated(since="0.1.46", alternative="invoke", removal="1.0")
|
||||
def get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
run_name: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
|
||||
Users should favor using `.invoke` or `.batch` rather than
|
||||
`get_relevant_documents` directly.
|
||||
|
||||
Args:
|
||||
query: string to find relevant documents for.
|
||||
callbacks: Callback manager or list of callbacks.
|
||||
tags: Optional list of tags associated with the retriever.
|
||||
These tags will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
|
||||
metadata: Optional metadata associated with the retriever.
|
||||
This metadata will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
|
||||
run_name: Optional name for the run.
|
||||
**kwargs: Additional arguments to pass to the retriever.
|
||||
|
||||
Returns:
|
||||
List of relevant documents.
|
||||
"""
|
||||
config: RunnableConfig = {}
|
||||
if callbacks:
|
||||
config["callbacks"] = callbacks
|
||||
if tags:
|
||||
config["tags"] = tags
|
||||
if metadata:
|
||||
config["metadata"] = metadata
|
||||
if run_name:
|
||||
config["run_name"] = run_name
|
||||
return self.invoke(query, config, **kwargs)
|
||||
|
||||
@deprecated(since="0.1.46", alternative="ainvoke", removal="1.0")
|
||||
async def aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
run_name: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
|
||||
Users should favor using `.ainvoke` or `.abatch` rather than
|
||||
`aget_relevant_documents` directly.
|
||||
|
||||
Args:
|
||||
query: string to find relevant documents for.
|
||||
callbacks: Callback manager or list of callbacks.
|
||||
tags: Optional list of tags associated with the retriever.
|
||||
These tags will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
|
||||
metadata: Optional metadata associated with the retriever.
|
||||
This metadata will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
|
||||
run_name: Optional name for the run.
|
||||
**kwargs: Additional arguments to pass to the retriever.
|
||||
|
||||
Returns:
|
||||
List of relevant documents.
|
||||
"""
|
||||
config: RunnableConfig = {}
|
||||
if callbacks:
|
||||
config["callbacks"] = callbacks
|
||||
if tags:
|
||||
config["tags"] = tags
|
||||
if metadata:
|
||||
config["metadata"] = metadata
|
||||
if run_name:
|
||||
config["run_name"] = run_name
|
||||
return await self.ainvoke(query, config, **kwargs)
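Since both methods above are deprecated in favor of `invoke` / `ainvoke`, here is a minimal sketch of the replacement call. The toy retriever is illustrative; the point is that callbacks, tags, metadata, and run_name now travel in the config dict.

```python
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class ToyRetriever(BaseRetriever):
    """Illustrative retriever returning a canned document."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content=f"stub result for {query!r}")]


retriever = ToyRetriever()
# Preferred over the deprecated get_relevant_documents().
docs = retriever.invoke("standard query", config={"tags": ["demo"], "run_name": "retrieval"})
```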
|
||||
|
||||
@@ -304,7 +304,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
TypeError: If the input type cannot be inferred.
|
||||
"""
|
||||
# First loop through all parent classes and if any of them is
|
||||
# a pydantic model, we will pick up the generic parameterization
|
||||
# a Pydantic model, we will pick up the generic parameterization
|
||||
# from that model via the __pydantic_generic_metadata__ attribute.
|
||||
for base in self.__class__.mro():
|
||||
if hasattr(base, "__pydantic_generic_metadata__"):
|
||||
@@ -312,7 +312,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
if "args" in metadata and len(metadata["args"]) == 2:
|
||||
return metadata["args"][0]
|
||||
|
||||
# If we didn't find a pydantic model in the parent classes,
|
||||
# If we didn't find a Pydantic model in the parent classes,
|
||||
# then loop through __orig_bases__. This corresponds to
|
||||
# Runnables that are not pydantic models.
|
||||
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||
@@ -390,7 +390,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
self.get_name("Input"),
|
||||
root=root_type,
|
||||
# create model needs access to appropriate type annotations to be
|
||||
# able to construct the pydantic model.
|
||||
# able to construct the Pydantic model.
|
||||
# When we create the model, we pass information about the namespace
|
||||
# where the model is being created, so the type annotations can
|
||||
# be resolved correctly as well.
|
||||
@@ -433,7 +433,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
def output_schema(self) -> type[BaseModel]:
|
||||
"""Output schema.
|
||||
|
||||
The type of output this `Runnable` produces specified as a pydantic model.
|
||||
The type of output this `Runnable` produces specified as a Pydantic model.
|
||||
"""
|
||||
return self.get_output_schema()
|
||||
|
||||
@@ -468,7 +468,7 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
self.get_name("Output"),
|
||||
root=root_type,
|
||||
# create model needs access to appropriate type annotations to be
|
||||
# able to construct the pydantic model.
|
||||
# able to construct the Pydantic model.
|
||||
# When we create the model, we pass information about the namespace
|
||||
# where the model is being created, so the type annotations can
|
||||
# be resolved correctly as well.
|
||||
@@ -776,11 +776,11 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||
+ "{question}"
|
||||
)
|
||||
llm = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
model = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
|
||||
chain: Runnable = prompt | llm | {"str": StrOutputParser()}
|
||||
chain: Runnable = prompt | model | {"str": StrOutputParser()}
|
||||
|
||||
chain_with_assign = chain.assign(hello=itemgetter("str") | llm)
|
||||
chain_with_assign = chain.assign(hello=itemgetter("str") | model)
|
||||
|
||||
print(chain_with_assign.input_schema.model_json_schema())
|
||||
# {'title': 'PromptInput', 'type': 'object', 'properties':
|
||||
@@ -1273,22 +1273,20 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
|
||||
A `StreamEvent` is a dictionary with the following schema:
|
||||
|
||||
- `event`: **str** - Event names are of the format:
|
||||
- `event`: Event names are of the format:
|
||||
`on_[runnable_type]_(start|stream|end)`.
|
||||
- `name`: **str** - The name of the `Runnable` that generated the event.
|
||||
- `run_id`: **str** - randomly generated ID associated with the given
|
||||
execution of the `Runnable` that emitted the event. A child `Runnable` that gets
|
||||
invoked as part of the execution of a parent `Runnable` is assigned its own
|
||||
unique ID.
|
||||
- `parent_ids`: **list[str]** - The IDs of the parent runnables that generated
|
||||
the event. The root `Runnable` will have an empty list. The order of the parent
|
||||
IDs is from the root to the immediate parent. Only available for v2 version of
|
||||
the API. The v1 version of the API will return an empty list.
|
||||
- `tags`: **list[str] | None** - The tags of the `Runnable` that generated
|
||||
the event.
|
||||
- `metadata`: **dict[str, Any] | None** - The metadata of the `Runnable` that
|
||||
generated the event.
|
||||
- `data`: **dict[str, Any]**
|
||||
- `name`: The name of the `Runnable` that generated the event.
|
||||
- `run_id`: Randomly generated ID associated with the given execution of the
|
||||
`Runnable` that emitted the event. A child `Runnable` that gets invoked as
|
||||
part of the execution of a parent `Runnable` is assigned its own unique ID.
|
||||
- `parent_ids`: The IDs of the parent runnables that generated the event. The
|
||||
root `Runnable` will have an empty list. The order of the parent IDs is from
|
||||
the root to the immediate parent. Only available for v2 version of the API.
|
||||
The v1 version of the API will return an empty list.
|
||||
- `tags`: The tags of the `Runnable` that generated the event.
|
||||
- `metadata`: The metadata of the `Runnable` that generated the event.
|
||||
- `data`: The data associated with the event. The contents of this field
|
||||
depend on the type of event. See the table below for more details.
|
||||
|
||||
Below is a table that illustrates some events that might be emitted by various
|
||||
chains. Metadata fields have been omitted from the table for brevity.
|
||||
@@ -1297,39 +1295,23 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
!!! note
|
||||
This reference table is for the v2 version of the schema.
|
||||
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| event | name | chunk | input | output |
|
||||
+==========================+==================+=====================================+===================================================+=====================================================+
|
||||
| `on_chat_model_start` | [model name] | | `{"messages": [[SystemMessage, HumanMessage]]}` | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_chat_model_stream` | [model name] | `AIMessageChunk(content="hello")` | | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_chat_model_end` | [model name] | | `{"messages": [[SystemMessage, HumanMessage]]}` | `AIMessageChunk(content="hello world")` |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_llm_start` | [model name] | | `{'input': 'hello'}` | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_llm_stream` | [model name] | `'Hello' ` | | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_llm_end` | [model name] | | `'Hello human!'` | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_chain_start` | format_docs | | | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_chain_stream` | format_docs | `'hello world!, goodbye world!'` | | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_chain_end` | format_docs | | `[Document(...)]` | `'hello world!, goodbye world!'` |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_tool_start` | some_tool | | `{"x": 1, "y": "2"}` | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_tool_end` | some_tool | | | `{"x": 1, "y": "2"}` |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_retriever_start` | [retriever name] | | `{"query": "hello"}` | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_retriever_end` | [retriever name] | | `{"query": "hello"}` | `[Document(...), ..]` |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_prompt_start` | [template_name] | | `{"question": "hello"}` | |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| `on_prompt_end` | [template_name] | | `{"question": "hello"}` | `ChatPromptValue(messages: [SystemMessage, ...])` |
|
||||
+--------------------------+------------------+-------------------------------------+---------------------------------------------------+-----------------------------------------------------+
|
||||
| event | name | chunk | input | output |
|
||||
| ---------------------- | -------------------- | ----------------------------------- | ------------------------------------------------- | --------------------------------------------------- |
|
||||
| `on_chat_model_start` | `'[model name]'` | | `{"messages": [[SystemMessage, HumanMessage]]}` | |
|
||||
| `on_chat_model_stream` | `'[model name]'` | `AIMessageChunk(content="hello")` | | |
|
||||
| `on_chat_model_end` | `'[model name]'` | | `{"messages": [[SystemMessage, HumanMessage]]}` | `AIMessageChunk(content="hello world")` |
|
||||
| `on_llm_start` | `'[model name]'` | | `{'input': 'hello'}` | |
|
||||
| `on_llm_stream` | `'[model name]'` | `'Hello' ` | | |
|
||||
| `on_llm_end` | `'[model name]'` | | `'Hello human!'` | |
|
||||
| `on_chain_start` | `'format_docs'` | | | |
|
||||
| `on_chain_stream` | `'format_docs'` | `'hello world!, goodbye world!'` | | |
|
||||
| `on_chain_end` | `'format_docs'` | | `[Document(...)]` | `'hello world!, goodbye world!'` |
|
||||
| `on_tool_start` | `'some_tool'` | | `{"x": 1, "y": "2"}` | |
|
||||
| `on_tool_end` | `'some_tool'` | | | `{"x": 1, "y": "2"}` |
|
||||
| `on_retriever_start` | `'[retriever name]'` | | `{"query": "hello"}` | |
|
||||
| `on_retriever_end` | `'[retriever name]'` | | `{"query": "hello"}` | `[Document(...), ..]` |
|
||||
| `on_prompt_start` | `'[template_name]'` | | `{"question": "hello"}` | |
|
||||
| `on_prompt_end` | `'[template_name]'` | | `{"question": "hello"}` | `ChatPromptValue(messages: [SystemMessage, ...])` |
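As a quick, hedged illustration of how these rows surface at runtime (not part of this change; the `reverse` lambda and the printed fields are illustrative only):

```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def main() -> None:
    chain = RunnableLambda(lambda text: text[::-1])

    # Each yielded event is a dict with the fields described above
    # (`event`, `name`, `run_id`, `data`, ...).
    async for event in chain.astream_events("hello", version="v2"):
        print(event["event"], event["name"], event["data"])


asyncio.run(main())
```

For a plain function wrapped in `RunnableLambda`, this emits the `on_chain_start`, `on_chain_stream`, and `on_chain_end` rows from the table.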
|
||||
|
||||
In addition to the standard events, users can also dispatch custom events (see example below).
|
||||
|
||||
@@ -1337,13 +1319,10 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
|
||||
A custom event has the following format:
|
||||
|
||||
+-----------+------+-----------------------------------------------------------------------------------------------------------+
|
||||
| Attribute | Type | Description |
|
||||
+===========+======+===========================================================================================================+
|
||||
| name | str | A user defined name for the event. |
|
||||
+-----------+------+-----------------------------------------------------------------------------------------------------------+
|
||||
| data | Any | The data associated with the event. This can be anything, though we suggest making it JSON serializable. |
|
||||
+-----------+------+-----------------------------------------------------------------------------------------------------------+
|
||||
| Attribute | Type | Description |
|
||||
| ----------- | ------ | --------------------------------------------------------------------------------------------------------- |
|
||||
| `name` | `str` | A user defined name for the event. |
|
||||
| `data` | `Any` | The data associated with the event. This can be anything, though we suggest making it JSON serializable. |
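A hedged sketch of dispatching such a custom event, assuming `adispatch_custom_event` is exported from `langchain_core.callbacks` in the installed version (the event name and payload are illustrative):

```python
import asyncio

from langchain_core.callbacks import adispatch_custom_event
from langchain_core.runnables import RunnableLambda


async def step(text: str) -> str:
    # Surfaces in `astream_events` as an `on_custom_event` entry.
    # On Python < 3.11, pass the current run's config explicitly via `config=`.
    await adispatch_custom_event("progress", {"chars": len(text)})
    return text.upper()


async def main() -> None:
    async for event in RunnableLambda(step).astream_events("hello", version="v2"):
        if event["event"] == "on_custom_event":
            print(event["name"], event["data"])


asyncio.run(main())
```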
|
||||
|
||||
Here are declarations associated with the standard events shown above:
|
||||
|
||||
@@ -1619,16 +1598,16 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
llm = ChatOllama(model="llama3.1")
|
||||
model = ChatOllama(model="llama3.1")
|
||||
|
||||
# Without bind
|
||||
chain = llm | StrOutputParser()
|
||||
chain = model | StrOutputParser()
|
||||
|
||||
chain.invoke("Repeat quoted words exactly: 'One two three four five.'")
|
||||
# Output is 'One two three four five.'
|
||||
|
||||
# With bind
|
||||
chain = llm.bind(stop=["three"]) | StrOutputParser()
|
||||
chain = model.bind(stop=["three"]) | StrOutputParser()
|
||||
|
||||
chain.invoke("Repeat quoted words exactly: 'One two three four five.'")
|
||||
# Output is 'One two'
|
||||
@@ -4493,7 +4472,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
@override
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
"""The pydantic schema for the input to this `Runnable`.
|
||||
"""The Pydantic schema for the input to this `Runnable`.
|
||||
|
||||
Args:
|
||||
config: The config to use.
|
||||
@@ -5127,7 +5106,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
|
||||
None,
|
||||
),
|
||||
# create model needs access to appropriate type annotations to be
|
||||
# able to construct the pydantic model.
|
||||
# able to construct the Pydantic model.
|
||||
# When we create the model, we pass information about the namespace
|
||||
# where the model is being created, so the type annotations can
|
||||
# be resolved correctly as well.
|
||||
@@ -5150,7 +5129,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]):
|
||||
self.get_name("Output"),
|
||||
root=list[schema], # type: ignore[valid-type]
|
||||
# create model needs access to appropriate type annotations to be
|
||||
# able to construct the pydantic model.
|
||||
# able to construct the Pydantic model.
|
||||
# When we create the model, we pass information about the namespace
|
||||
# where the model is being created, so the type annotations can
|
||||
# be resolved correctly as well.
|
||||
@@ -5387,13 +5366,13 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[
|
||||
custom_input_type: Any | None = None
|
||||
"""Override the input type of the underlying `Runnable` with a custom type.
|
||||
|
||||
The type can be a pydantic model, or a type annotation (e.g., `list[str]`).
|
||||
The type can be a Pydantic model, or a type annotation (e.g., `list[str]`).
|
||||
"""
|
||||
# Union[Type[Output], BaseModel] + things like list[str]
|
||||
custom_output_type: Any | None = None
|
||||
"""Override the output type of the underlying `Runnable` with a custom type.
|
||||
|
||||
The type can be a pydantic model, or a type annotation (e.g., `list[str]`).
|
||||
The type can be a Pydantic model, or a type annotation (e.g., `list[str]`).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -6077,10 +6056,10 @@ def chain(
|
||||
@chain
|
||||
def my_func(fields):
|
||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||
llm = OpenAI()
|
||||
model = OpenAI()
|
||||
formatted = prompt.invoke(fields)
|
||||
|
||||
for chunk in llm.stream(formatted):
|
||||
for chunk in model.stream(formatted):
|
||||
yield chunk
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -594,7 +594,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
Returns:
|
||||
If the attribute is anything other than a method that outputs a Runnable,
|
||||
returns getattr(self.runnable, name). If the attribute is a method that
|
||||
does return a new Runnable (e.g. llm.bind_tools([...]) outputs a new
|
||||
does return a new Runnable (e.g. model.bind_tools([...]) outputs a new
|
||||
RunnableBinding) then self.runnable and each of the runnables in
|
||||
self.fallbacks is replaced with getattr(x, name).
|
||||
|
||||
@@ -605,15 +605,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
|
||||
gpt_4o = ChatOpenAI(model="gpt-4o")
|
||||
claude_3_sonnet = ChatAnthropic(model="claude-3-7-sonnet-20250219")
|
||||
llm = gpt_4o.with_fallbacks([claude_3_sonnet])
|
||||
model = gpt_4o.with_fallbacks([claude_3_sonnet])
|
||||
|
||||
llm.model_name
|
||||
model.model_name
|
||||
# -> "gpt-4o"
|
||||
|
||||
# .bind_tools() is called on both ChatOpenAI and ChatAnthropic
|
||||
# Equivalent to:
|
||||
# gpt_4o.bind_tools([...]).with_fallbacks([claude_3_sonnet.bind_tools([...])])
|
||||
llm.bind_tools([...])
|
||||
model.bind_tools([...])
|
||||
# -> RunnableWithFallbacks(
|
||||
runnable=RunnableBinding(bound=ChatOpenAI(...), kwargs={"tools": [...]}),
|
||||
fallbacks=[RunnableBinding(bound=ChatAnthropic(...), kwargs={"tools": [...]})],
|
||||
|
||||
@@ -52,7 +52,7 @@ def is_uuid(value: str) -> bool:
|
||||
value: The string to check.
|
||||
|
||||
Returns:
|
||||
True if the string is a valid UUID, False otherwise.
|
||||
`True` if the string is a valid UUID, `False` otherwise.
|
||||
"""
|
||||
try:
|
||||
UUID(value)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Module contains typedefs that are used with Runnables."""
|
||||
"""Module contains typedefs that are used with `Runnable` objects."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -14,43 +14,43 @@ class EventData(TypedDict, total=False):
|
||||
"""Data associated with a streaming event."""
|
||||
|
||||
input: Any
|
||||
"""The input passed to the Runnable that generated the event.
|
||||
"""The input passed to the `Runnable` that generated the event.
|
||||
|
||||
Inputs will sometimes be available at the *START* of the Runnable, and
|
||||
sometimes at the *END* of the Runnable.
|
||||
Inputs will sometimes be available at the *START* of the `Runnable`, and
|
||||
sometimes at the *END* of the `Runnable`.
|
||||
|
||||
If a Runnable is able to stream its inputs, then its input by definition
|
||||
won't be known until the *END* of the Runnable when it has finished streaming
|
||||
If a `Runnable` is able to stream its inputs, then its input by definition
|
||||
won't be known until the *END* of the `Runnable` when it has finished streaming
|
||||
its inputs.
|
||||
"""
|
||||
error: NotRequired[BaseException]
|
||||
"""The error that occurred during the execution of the Runnable.
|
||||
"""The error that occurred during the execution of the `Runnable`.
|
||||
|
||||
This field is only available if the Runnable raised an exception.
|
||||
This field is only available if the `Runnable` raised an exception.
|
||||
|
||||
!!! version-added "Added in version 1.0.0"
|
||||
"""
|
||||
output: Any
|
||||
"""The output of the Runnable that generated the event.
|
||||
"""The output of the `Runnable` that generated the event.
|
||||
|
||||
Outputs will only be available at the *END* of the Runnable.
|
||||
Outputs will only be available at the *END* of the `Runnable`.
|
||||
|
||||
For most Runnables, this field can be inferred from the `chunk` field,
|
||||
though there might be some exceptions for special cased Runnables (e.g., like
|
||||
For most `Runnable` objects, this field can be inferred from the `chunk` field,
|
||||
though there might be some exceptions for a special-cased `Runnable` (e.g., like
|
||||
chat models), which may return more information.
|
||||
"""
|
||||
chunk: Any
|
||||
"""A streaming chunk from the output that generated the event.
|
||||
|
||||
Chunks support addition in general, and adding them up should result
|
||||
in the output of the Runnable that generated the event.
|
||||
in the output of the `Runnable` that generated the event.
|
||||
"""
|
||||
|
||||
|
||||
class BaseStreamEvent(TypedDict):
|
||||
"""Streaming event.
|
||||
|
||||
Schema of a streaming event which is produced from the astream_events method.
|
||||
Schema of a streaming event which is produced from the `astream_events` method.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -94,45 +94,45 @@ class BaseStreamEvent(TypedDict):
|
||||
"""
|
||||
|
||||
event: str
|
||||
"""Event names are of the format: on_[runnable_type]_(start|stream|end).
|
||||
"""Event names are of the format: `on_[runnable_type]_(start|stream|end)`.
|
||||
|
||||
Runnable types are one of:
|
||||
|
||||
- **llm** - used by non chat models
|
||||
- **chat_model** - used by chat models
|
||||
- **prompt** -- e.g., ChatPromptTemplate
|
||||
- **tool** -- from tools defined via @tool decorator or inheriting
|
||||
from Tool/BaseTool
|
||||
- **chain** - most Runnables are of this type
|
||||
- **prompt** -- e.g., `ChatPromptTemplate`
|
||||
- **tool** -- from tools defined via the `@tool` decorator or inheriting
|
||||
from `Tool`/`BaseTool`
|
||||
- **chain** - most `Runnable` objects are of this type
|
||||
|
||||
Further, the events are categorized as one of:
|
||||
|
||||
- **start** - when the Runnable starts
|
||||
- **stream** - when the Runnable is streaming
|
||||
- **end* - when the Runnable ends
|
||||
- **start** - when the `Runnable` starts
|
||||
- **stream** - when the `Runnable` is streaming
|
||||
- **end** - when the `Runnable` ends
|
||||
|
||||
`start`, `stream`, and `end` are associated with slightly different `data` payloads.
|
||||
|
||||
Please see the documentation for `EventData` for more details.
|
||||
"""
|
||||
run_id: str
|
||||
"""An randomly generated ID to keep track of the execution of the given Runnable.
|
||||
"""An randomly generated ID to keep track of the execution of the given `Runnable`.
|
||||
|
||||
Each child Runnable that gets invoked as part of the execution of a parent Runnable
|
||||
is assigned its own unique ID.
|
||||
Each child `Runnable` that gets invoked as part of the execution of a parent
|
||||
`Runnable` is assigned its own unique ID.
|
||||
"""
|
||||
tags: NotRequired[list[str]]
|
||||
"""Tags associated with the Runnable that generated this event.
|
||||
"""Tags associated with the `Runnable` that generated this event.
|
||||
|
||||
Tags are always inherited from parent Runnables.
|
||||
Tags are always inherited from parent `Runnable` objects.
|
||||
|
||||
Tags can either be bound to a Runnable using `.with_config({"tags": ["hello"]})`
|
||||
Tags can either be bound to a `Runnable` using `.with_config({"tags": ["hello"]})`
|
||||
or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
|
||||
"""
|
||||
metadata: NotRequired[dict[str, Any]]
|
||||
"""Metadata associated with the Runnable that generated this event.
|
||||
"""Metadata associated with the `Runnable` that generated this event.
|
||||
|
||||
Metadata can either be bound to a Runnable using
|
||||
Metadata can either be bound to a `Runnable` using
|
||||
|
||||
`.with_config({"metadata": { "foo": "bar" }})`
|
||||
|
||||
@@ -146,8 +146,8 @@ class BaseStreamEvent(TypedDict):
|
||||
|
||||
Root Events will have an empty list.
|
||||
|
||||
For example, if a Runnable A calls Runnable B, then the event generated by Runnable
|
||||
B will have Runnable A's ID in the parent_ids field.
|
||||
For example, if a `Runnable` A calls `Runnable` B, then the event generated by
|
||||
`Runnable` B will have `Runnable` A's ID in the `parent_ids` field.
|
||||
|
||||
The order of the parent IDs is from the root parent to the immediate parent.
|
||||
|
||||
@@ -164,7 +164,7 @@ class StandardStreamEvent(BaseStreamEvent):
|
||||
The contents of the event data depend on the event type.
|
||||
"""
|
||||
name: str
|
||||
"""The name of the Runnable that generated the event."""
|
||||
"""The name of the `Runnable` that generated the event."""
|
||||
|
||||
|
||||
class CustomStreamEvent(BaseStreamEvent):
|
||||
|
||||
@@ -80,7 +80,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||
callable: The callable to check.
|
||||
|
||||
Returns:
|
||||
True if the callable accepts a run_manager argument, False otherwise.
|
||||
`True` if the callable accepts a run_manager argument, `False` otherwise.
|
||||
"""
|
||||
try:
|
||||
return signature(callable).parameters.get("run_manager") is not None
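The same `inspect.signature` probe backs `accepts_config` and `accepts_context` below; a tiny standalone sketch of the pattern (illustrative, not taken from the diff):

```python
from inspect import signature


def handler(query: str, *, run_manager=None) -> str:
    return query


# Look the parameter up by name on the callable's signature and treat
# "present" as "accepted".
print(signature(handler).parameters.get("run_manager") is not None)  # True
print(signature(handler).parameters.get("config") is not None)  # False
```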
|
||||
@@ -95,7 +95,7 @@ def accepts_config(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||
callable: The callable to check.
|
||||
|
||||
Returns:
|
||||
True if the callable accepts a config argument, False otherwise.
|
||||
`True` if the callable accepts a config argument, `False` otherwise.
|
||||
"""
|
||||
try:
|
||||
return signature(callable).parameters.get("config") is not None
|
||||
@@ -110,7 +110,7 @@ def accepts_context(callable: Callable[..., Any]) -> bool: # noqa: A002
|
||||
callable: The callable to check.
|
||||
|
||||
Returns:
|
||||
True if the callable accepts a context argument, False otherwise.
|
||||
`True` if the callable accepts a context argument, `False` otherwise.
|
||||
"""
|
||||
try:
|
||||
return signature(callable).parameters.get("context") is not None
|
||||
@@ -123,7 +123,7 @@ def asyncio_accepts_context() -> bool:
|
||||
"""Cache the result of checking if asyncio.create_task accepts a `context` arg.
|
||||
|
||||
Returns:
|
||||
True if `asyncio.create_task` accepts a context argument, False otherwise.
|
||||
`True` if `asyncio.create_task` accepts a context argument, `False` otherwise.
|
||||
"""
|
||||
return accepts_context(asyncio.create_task)
|
||||
|
||||
@@ -727,7 +727,7 @@ def is_async_generator(
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
True if the function is an async generator, False otherwise.
|
||||
`True` if the function is an async generator, `False` otherwise.
|
||||
"""
|
||||
return inspect.isasyncgenfunction(func) or (
|
||||
hasattr(func, "__call__") # noqa: B004
|
||||
@@ -744,7 +744,7 @@ def is_async_callable(
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
True if the function is async, False otherwise.
|
||||
`True` if the function is async, `False` otherwise.
|
||||
"""
|
||||
return asyncio.iscoroutinefunction(func) or (
|
||||
hasattr(func, "__call__") # noqa: B004
|
||||
|
||||
@@ -92,7 +92,7 @@ def _is_annotated_type(typ: type[Any]) -> bool:
|
||||
typ: The type to check.
|
||||
|
||||
Returns:
|
||||
True if the type is an Annotated type, False otherwise.
|
||||
`True` if the type is an Annotated type, `False` otherwise.
|
||||
"""
|
||||
return get_origin(typ) is typing.Annotated
|
||||
|
||||
@@ -226,7 +226,7 @@ def _is_pydantic_annotation(annotation: Any, pydantic_version: str = "v2") -> bo
|
||||
pydantic_version: The Pydantic version to check against ("v1" or "v2").
|
||||
|
||||
Returns:
|
||||
True if the annotation is a Pydantic model, False otherwise.
|
||||
`True` if the annotation is a Pydantic model, `False` otherwise.
|
||||
"""
|
||||
base_model_class = BaseModelV1 if pydantic_version == "v1" else BaseModel
|
||||
try:
|
||||
@@ -245,7 +245,7 @@ def _function_annotations_are_pydantic_v1(
|
||||
func: The function being checked.
|
||||
|
||||
Returns:
|
||||
True if all Pydantic annotations are from V1, False otherwise.
|
||||
`True` if all Pydantic annotations are from V1, `False` otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the function contains mixed V1 and V2 annotations.
|
||||
@@ -285,17 +285,17 @@ def create_schema_from_function(
|
||||
error_on_invalid_docstring: bool = False,
|
||||
include_injected: bool = True,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature.
|
||||
"""Create a Pydantic schema from a function's signature.
|
||||
|
||||
Args:
|
||||
model_name: Name to assign to the generated pydantic schema.
|
||||
model_name: Name to assign to the generated Pydantic schema.
|
||||
func: Function to generate the schema from.
|
||||
filter_args: Optional list of arguments to exclude from the schema.
|
||||
Defaults to FILTERED_ARGS.
|
||||
Defaults to `FILTERED_ARGS`.
|
||||
parse_docstring: Whether to parse the function's docstring for descriptions
|
||||
for each argument. Defaults to `False`.
|
||||
error_on_invalid_docstring: if `parse_docstring` is provided, configure
|
||||
whether to raise ValueError on invalid Google Style docstrings.
|
||||
whether to raise `ValueError` on invalid Google Style docstrings.
|
||||
Defaults to `False`.
|
||||
include_injected: Whether to include injected arguments in the schema.
|
||||
Defaults to `True`, since we want to include them in the schema
|
||||
@@ -312,7 +312,7 @@ def create_schema_from_function(
|
||||
# https://docs.pydantic.dev/latest/usage/validation_decorator/
|
||||
with warnings.catch_warnings():
|
||||
# We are using deprecated functionality here.
|
||||
# This code should be re-written to simply construct a pydantic model
|
||||
# This code should be re-written to simply construct a Pydantic model
|
||||
# using inspect.signature and create_model.
|
||||
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
|
||||
validated = validate_arguments(func, config=_SchemaConfig) # type: ignore[operator]
|
||||
@@ -517,7 +517,7 @@ class ChildTool(BaseTool):
|
||||
"""Check if the tool accepts only a single input argument.
|
||||
|
||||
Returns:
|
||||
True if the tool has only one input argument, False otherwise.
|
||||
`True` if the tool has only one input argument, `False` otherwise.
|
||||
"""
|
||||
keys = {k for k in self.args if k != "kwargs"}
|
||||
return len(keys) == 1
|
||||
@@ -981,7 +981,7 @@ def _is_tool_call(x: Any) -> bool:
|
||||
x: The input to check.
|
||||
|
||||
Returns:
|
||||
True if the input is a tool call, False otherwise.
|
||||
`True` if the input is a tool call, `False` otherwise.
|
||||
"""
|
||||
return isinstance(x, dict) and x.get("type") == "tool_call"
|
||||
|
||||
@@ -1128,7 +1128,7 @@ def _is_message_content_type(obj: Any) -> bool:
|
||||
obj: The object to check.
|
||||
|
||||
Returns:
|
||||
True if the object is valid message content, False otherwise.
|
||||
`True` if the object is valid message content, `False` otherwise.
|
||||
"""
|
||||
return isinstance(obj, str) or (
|
||||
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
|
||||
@@ -1144,7 +1144,7 @@ def _is_message_content_block(obj: Any) -> bool:
|
||||
obj: The object to check.
|
||||
|
||||
Returns:
|
||||
True if the object is a valid content block, False otherwise.
|
||||
`True` if the object is a valid content block, `False` otherwise.
|
||||
"""
|
||||
if isinstance(obj, str):
|
||||
return True
|
||||
@@ -1248,7 +1248,7 @@ def _is_injected_arg_type(
|
||||
injected_type: The specific injected type to check for.
|
||||
|
||||
Returns:
|
||||
True if the type is an injected argument, False otherwise.
|
||||
`True` if the type is an injected argument, `False` otherwise.
|
||||
"""
|
||||
injected_type = injected_type or InjectedToolArg
|
||||
return any(
|
||||
|
||||
@@ -69,7 +69,7 @@ class Tool(BaseTool):
|
||||
def _to_args_and_kwargs(
|
||||
self, tool_input: str | dict, tool_call_id: str | None
|
||||
) -> tuple[tuple, dict]:
|
||||
"""Convert tool input to pydantic model.
|
||||
"""Convert tool input to Pydantic model.
|
||||
|
||||
Args:
|
||||
tool_input: The input to the tool.
|
||||
@@ -79,8 +79,7 @@ class Tool(BaseTool):
|
||||
ToolException: If the tool input is invalid.
|
||||
|
||||
Returns:
|
||||
the pydantic model args and kwargs.
|
||||
|
||||
The Pydantic model args and kwargs.
|
||||
"""
|
||||
args, kwargs = super()._to_args_and_kwargs(tool_input, tool_call_id)
|
||||
# For backwards compatibility. The tool must be run with a single input
|
||||
|
||||
@@ -190,7 +190,7 @@ class RunLog(RunLogPatch):
|
||||
other: The other `RunLog` to compare to.
|
||||
|
||||
Returns:
|
||||
True if the `RunLog`s are equal, False otherwise.
|
||||
`True` if the `RunLog`s are equal, `False` otherwise.
|
||||
"""
|
||||
# First compare that the state is the same
|
||||
if not isinstance(other, RunLog):
|
||||
@@ -288,7 +288,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
||||
*ops: The operations to send to the stream.
|
||||
|
||||
Returns:
|
||||
True if the patch was sent successfully, False if the stream is closed.
|
||||
`True` if the patch was sent successfully, `False` if the stream is closed.
|
||||
"""
|
||||
# We will likely want to wrap this in try / except at some point
|
||||
# to handle exceptions that might arise at run time.
|
||||
@@ -368,7 +368,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
|
||||
run: The Run to check.
|
||||
|
||||
Returns:
|
||||
True if the run should be included, False otherwise.
|
||||
`True` if the run should be included, `False` otherwise.
|
||||
"""
|
||||
if run.id == self.root_id:
|
||||
return False
|
||||
|
||||
@@ -13,7 +13,7 @@ def env_var_is_set(env_var: str) -> bool:
|
||||
env_var: The name of the environment variable.
|
||||
|
||||
Returns:
|
||||
True if the environment variable is set, False otherwise.
|
||||
`True` if the environment variable is set, `False` otherwise.
|
||||
"""
|
||||
return env_var in os.environ and os.environ[env_var] not in {
|
||||
"",
|
||||
|
||||
@@ -713,7 +713,7 @@ def tool_example_to_messages(
|
||||
"type": "function",
|
||||
"function": {
|
||||
# The name of the function right now corresponds to the name
|
||||
# of the pydantic model. This is implicit in the API right now,
|
||||
# of the Pydantic model. This is implicit in the API right now,
|
||||
# and will be improved over time.
|
||||
"name": tool_call.__class__.__name__,
|
||||
"arguments": tool_call.model_dump_json(),
|
||||
|
||||
@@ -7,6 +7,6 @@ def is_interactive_env() -> bool:
|
||||
"""Determine if running within IPython or Jupyter.
|
||||
|
||||
Returns:
|
||||
True if running in an interactive environment, False otherwise.
|
||||
`True` if running in an interactive environment, `False` otherwise.
|
||||
"""
|
||||
return hasattr(sys, "ps2")
|
||||
|
||||
@@ -78,7 +78,7 @@ def is_pydantic_v1_subclass(cls: type) -> bool:
|
||||
"""Check if the given class is Pydantic v1-like.
|
||||
|
||||
Returns:
|
||||
True if the given class is a subclass of Pydantic `BaseModel` 1.x.
|
||||
`True` if the given class is a subclass of Pydantic `BaseModel` 1.x.
|
||||
"""
|
||||
return issubclass(cls, BaseModelV1)
|
||||
|
||||
@@ -87,7 +87,7 @@ def is_pydantic_v2_subclass(cls: type) -> bool:
|
||||
"""Check if the given class is Pydantic v2-like.
|
||||
|
||||
Returns:
|
||||
True if the given class is a subclass of Pydantic BaseModel 2.x.
|
||||
`True` if the given class is a subclass of Pydantic `BaseModel` 2.x.
|
||||
"""
|
||||
return issubclass(cls, BaseModel)
|
||||
|
||||
@@ -101,7 +101,7 @@ def is_basemodel_subclass(cls: type) -> bool:
|
||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||
|
||||
Returns:
|
||||
True if the given class is a subclass of Pydantic `BaseModel`.
|
||||
`True` if the given class is a subclass of Pydantic `BaseModel`.
|
||||
"""
|
||||
# Before we can use issubclass on the cls we need to check if it is a class
|
||||
if not inspect.isclass(cls) or isinstance(cls, GenericAlias):
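A hedged illustration of what the check accepts (the model names are illustrative; this is not the library's implementation):

```python
import inspect

from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1


class V2Model(BaseModel): ...


class V1Model(BaseModelV1): ...


# Guard with isclass first (as above), then accept either Pydantic lineage.
for candidate in (V2Model, V1Model, "not a class"):
    ok = inspect.isclass(candidate) and issubclass(candidate, (BaseModel, BaseModelV1))
    print(candidate, ok)
```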
|
||||
@@ -119,7 +119,7 @@ def is_basemodel_instance(obj: Any) -> bool:
|
||||
* pydantic.v1.BaseModel in Pydantic 2.x
|
||||
|
||||
Returns:
|
||||
True if the given class is an instance of Pydantic `BaseModel`.
|
||||
`True` if the given object is an instance of Pydantic `BaseModel`.
|
||||
"""
|
||||
return isinstance(obj, (BaseModel, BaseModelV1))
|
||||
|
||||
@@ -206,7 +206,7 @@ def _create_subset_model_v1(
|
||||
descriptions: dict | None = None,
|
||||
fn_description: str | None = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
"""Create a Pydantic model with only a subset of model's fields."""
|
||||
fields = {}
|
||||
|
||||
for field_name in field_names:
|
||||
@@ -235,7 +235,7 @@ def _create_subset_model_v2(
|
||||
descriptions: dict | None = None,
|
||||
fn_description: str | None = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with a subset of the model fields."""
|
||||
"""Create a Pydantic model with a subset of the model fields."""
|
||||
descriptions_ = descriptions or {}
|
||||
fields = {}
|
||||
for field_name in field_names:
|
||||
@@ -438,9 +438,9 @@ def create_model(
|
||||
/,
|
||||
**field_definitions: Any,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with the given field definitions.
|
||||
"""Create a Pydantic model with the given field definitions.
|
||||
|
||||
Please use create_model_v2 instead of this function.
|
||||
Please use `create_model_v2` instead of this function.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model.
|
||||
@@ -511,7 +511,7 @@ def create_model_v2(
|
||||
field_definitions: dict[str, Any] | None = None,
|
||||
root: Any | None = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with the given field definitions.
|
||||
"""Create a Pydantic model with the given field definitions.
|
||||
|
||||
Attention:
|
||||
Please do not use outside of langchain packages. This API
|
||||
@@ -522,7 +522,7 @@ def create_model_v2(
|
||||
module_name: The name of the module where the model is defined.
|
||||
This is used by Pydantic to resolve any forward references.
|
||||
field_definitions: The field definitions for the model.
|
||||
root: Type for a root model (RootModel)
|
||||
root: Type for a root model (`RootModel`)
|
||||
|
||||
Returns:
|
||||
The created model.
|
||||
|
||||
@@ -57,7 +57,7 @@ def _content_blocks_equal_ignore_id(
|
||||
expected: Expected content to compare against (string or list of blocks).
|
||||
|
||||
Returns:
|
||||
True if content matches (excluding `id` fields), False otherwise.
|
||||
`True` if content matches (excluding `id` fields), `False` otherwise.
|
||||
|
||||
"""
|
||||
if isinstance(actual, str) or isinstance(expected, str):
|
||||
|
||||
@@ -1934,11 +1934,10 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
|
||||
|
||||
|
||||
def test_args_schema_explicitly_typed() -> None:
|
||||
"""This should test that one can type the args schema as a pydantic model.
|
||||
|
||||
Please note that this will test using pydantic 2 even though BaseTool
|
||||
is a pydantic 1 model!
|
||||
"""This should test that one can type the args schema as a Pydantic model.
|
||||
|
||||
Please note that this will test using Pydantic 2 even though `BaseTool`
|
||||
is a Pydantic 1 model!
|
||||
"""
|
||||
|
||||
class Foo(BaseModel):
|
||||
@@ -1981,7 +1980,7 @@ def test_args_schema_explicitly_typed() -> None:
|
||||
|
||||
@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
|
||||
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
|
||||
"""This should test that one can type the args schema as a pydantic model."""
|
||||
"""This should test that one can type the args schema as a Pydantic model."""
|
||||
|
||||
def foo(a: int, b: str) -> str:
|
||||
"""Hahaha."""
|
||||
|
||||
@@ -7,9 +7,9 @@ from langchain_core.tracers.context import collect_runs
|
||||
|
||||
|
||||
def test_collect_runs() -> None:
|
||||
llm = FakeListLLM(responses=["hello"])
|
||||
model = FakeListLLM(responses=["hello"])
|
||||
with collect_runs() as cb:
|
||||
llm.invoke("hi")
|
||||
model.invoke("hi")
|
||||
assert cb.traced_runs
|
||||
assert len(cb.traced_runs) == 1
|
||||
assert isinstance(cb.traced_runs[0].id, uuid.UUID)
|
||||
|
||||
@@ -124,7 +124,7 @@ class BaseSingleActionAgent(BaseModel):
|
||||
along with observations.
|
||||
|
||||
Returns:
|
||||
AgentFinish: Agent finish object.
|
||||
Agent finish object.
|
||||
|
||||
Raises:
|
||||
ValueError: If `early_stopping_method` is not supported.
|
||||
@@ -155,7 +155,7 @@ class BaseSingleActionAgent(BaseModel):
|
||||
kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
BaseSingleActionAgent: Agent object.
|
||||
Agent object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -169,7 +169,7 @@ class BaseSingleActionAgent(BaseModel):
|
||||
"""Return dictionary representation of agent.
|
||||
|
||||
Returns:
|
||||
Dict: Dictionary representation of agent.
|
||||
Dictionary representation of agent.
|
||||
"""
|
||||
_dict = super().model_dump()
|
||||
try:
|
||||
@@ -233,7 +233,7 @@ class BaseMultiActionAgent(BaseModel):
|
||||
"""Get allowed tools.
|
||||
|
||||
Returns:
|
||||
list[str] | None: Allowed tools.
|
||||
Allowed tools.
|
||||
"""
|
||||
return None
|
||||
|
||||
@@ -297,7 +297,7 @@ class BaseMultiActionAgent(BaseModel):
|
||||
along with observations.
|
||||
|
||||
Returns:
|
||||
AgentFinish: Agent finish object.
|
||||
Agent finish object.
|
||||
|
||||
Raises:
|
||||
ValueError: If `early_stopping_method` is not supported.
|
||||
@@ -388,8 +388,7 @@ class MultiActionAgentOutputParser(
|
||||
text: Text to parse.
|
||||
|
||||
Returns:
|
||||
Union[List[AgentAction], AgentFinish]:
|
||||
List of agent actions or agent finish.
|
||||
List of agent actions or agent finish.
|
||||
"""
|
||||
|
||||
|
||||
@@ -812,7 +811,7 @@ class Agent(BaseSingleActionAgent):
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Full inputs for the LLMChain.
|
||||
Full inputs for the LLMChain.
|
||||
"""
|
||||
thoughts = self._construct_scratchpad(intermediate_steps)
|
||||
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
|
||||
@@ -834,7 +833,7 @@ class Agent(BaseSingleActionAgent):
|
||||
values: Values to validate.
|
||||
|
||||
Returns:
|
||||
Dict: Validated values.
|
||||
Validated values.
|
||||
|
||||
Raises:
|
||||
ValueError: If `agent_scratchpad` is not in prompt.input_variables
|
||||
@@ -875,7 +874,7 @@ class Agent(BaseSingleActionAgent):
|
||||
tools: Tools to use.
|
||||
|
||||
Returns:
|
||||
BasePromptTemplate: Prompt template.
|
||||
Prompt template.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -910,7 +909,7 @@ class Agent(BaseSingleActionAgent):
|
||||
kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
Agent: Agent object.
|
||||
Agent object.
|
||||
"""
|
||||
cls._validate_tools(tools)
|
||||
llm_chain = LLMChain(
|
||||
@@ -942,7 +941,7 @@ class Agent(BaseSingleActionAgent):
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
AgentFinish: Agent finish object.
|
||||
Agent finish object.
|
||||
|
||||
Raises:
|
||||
ValueError: If `early_stopping_method` is not in ['force', 'generate'].
|
||||
@@ -1082,7 +1081,7 @@ class AgentExecutor(Chain):
|
||||
kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
AgentExecutor: Agent executor object.
|
||||
Agent executor object.
|
||||
"""
|
||||
return cls(
|
||||
agent=agent,
|
||||
@@ -1099,7 +1098,7 @@ class AgentExecutor(Chain):
|
||||
values: Values to validate.
|
||||
|
||||
Returns:
|
||||
Dict: Validated values.
|
||||
Validated values.
|
||||
|
||||
Raises:
|
||||
ValueError: If allowed tools are different than provided tools.
|
||||
@@ -1126,7 +1125,7 @@ class AgentExecutor(Chain):
|
||||
values: Values to validate.
|
||||
|
||||
Returns:
|
||||
Dict: Validated values.
|
||||
Validated values.
|
||||
"""
|
||||
agent = values.get("agent")
|
||||
if agent and isinstance(agent, Runnable):
|
||||
@@ -1209,7 +1208,7 @@ class AgentExecutor(Chain):
|
||||
async_: Whether to run async. (Ignored)
|
||||
|
||||
Returns:
|
||||
AgentExecutorIterator: Agent executor iterator object.
|
||||
Agent executor iterator object.
|
||||
"""
|
||||
return AgentExecutorIterator(
|
||||
self,
|
||||
@@ -1244,7 +1243,7 @@ class AgentExecutor(Chain):
|
||||
name: Name of tool.
|
||||
|
||||
Returns:
|
||||
BaseTool: Tool object.
|
||||
Tool object.
|
||||
"""
|
||||
return {tool.name: tool for tool in self.tools}[name]
|
||||
|
||||
@@ -1759,7 +1758,7 @@ class AgentExecutor(Chain):
|
||||
kwargs: Additional arguments.
|
||||
|
||||
Yields:
|
||||
AddableDict: Addable dictionary.
|
||||
Addable dictionary.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
iterator = AgentExecutorIterator(
|
||||
@@ -1790,7 +1789,7 @@ class AgentExecutor(Chain):
|
||||
kwargs: Additional arguments.
|
||||
|
||||
Yields:
|
||||
AddableDict: Addable dictionary.
|
||||
Addable dictionary.
|
||||
"""
|
||||
config = ensure_config(config)
|
||||
iterator = AgentExecutorIterator(
|
||||
|
||||
@@ -58,7 +58,7 @@ def create_vectorstore_agent(
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
|
||||
vector_store = InMemoryVectorStore.from_texts(
|
||||
[
|
||||
@@ -74,7 +74,7 @@ def create_vectorstore_agent(
|
||||
"Fetches information about pets.",
|
||||
)
|
||||
|
||||
agent = create_react_agent(llm, [tool])
|
||||
agent = create_react_agent(model, [tool])
|
||||
|
||||
for step in agent.stream(
|
||||
{"messages": [("human", "What are dogs known for?")]},
|
||||
@@ -156,7 +156,7 @@ def create_vectorstore_router_agent(
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
|
||||
pet_vector_store = InMemoryVectorStore.from_texts(
|
||||
[
|
||||
@@ -187,7 +187,7 @@ def create_vectorstore_router_agent(
|
||||
),
|
||||
]
|
||||
|
||||
agent = create_react_agent(llm, tools)
|
||||
agent = create_react_agent(model, tools)
|
||||
|
||||
for step in agent.stream(
|
||||
{"messages": [("human", "Tell me about carrots.")]},
|
||||
|
||||
@@ -16,7 +16,7 @@ def format_log_to_str(
|
||||
Defaults to "Thought: ".
|
||||
|
||||
Returns:
|
||||
str: The scratchpad.
|
||||
The scratchpad.
|
||||
"""
|
||||
thoughts = ""
|
||||
for action, observation in intermediate_steps:
|
||||
|
||||
@@ -14,7 +14,7 @@ def format_log_to_messages(
|
||||
Defaults to "{observation}".
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The scratchpad.
|
||||
The scratchpad.
|
||||
"""
|
||||
thoughts: list[BaseMessage] = []
|
||||
for action, observation in intermediate_steps:
|
||||
|
||||
@@ -42,7 +42,7 @@ def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
|
||||
limit_to_domains: The allowed domains.
|
||||
|
||||
Returns:
|
||||
True if the URL is in the allowed domains, False otherwise.
|
||||
`True` if the URL is in the allowed domains, `False` otherwise.
|
||||
"""
|
||||
scheme, domain = _extract_scheme_and_domain(url)
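`_extract_scheme_and_domain` is a private helper, so here is a rough standalone equivalent of the allow-list check (an assumption sketched with `urllib.parse`, not the library's code):

```python
from urllib.parse import urlparse


def in_allowed_domain(url: str, limit_to_domains: list[str]) -> bool:
    """Return True when the URL's scheme and host match an allowed entry."""
    parsed = urlparse(url)
    return any(
        (urlparse(allowed).scheme, urlparse(allowed).netloc)
        == (parsed.scheme, parsed.netloc)
        for allowed in limit_to_domains
    )


print(in_allowed_domain("https://api.example.com/v1", ["https://api.example.com"]))  # True
print(in_allowed_domain("http://api.example.com/v1", ["https://api.example.com"]))   # False
```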
|
||||
|
||||
@@ -143,7 +143,7 @@ try:
|
||||
description: Limit the number of results
|
||||
\"\"\"
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
toolkit = RequestsToolkit(
|
||||
requests_wrapper=TextRequestsWrapper(headers={}), # no auth required
|
||||
allow_dangerous_requests=ALLOW_DANGEROUS_REQUESTS,
|
||||
@@ -152,7 +152,7 @@ try:
|
||||
|
||||
api_request_chain = (
|
||||
API_URL_PROMPT.partial(api_docs=api_spec)
|
||||
| llm.bind_tools(tools, tool_choice="any")
|
||||
| model.bind_tools(tools, tool_choice="any")
|
||||
)
|
||||
|
||||
class ChainState(TypedDict):
|
||||
@@ -169,7 +169,7 @@ try:
|
||||
return {"messages": [response]}
|
||||
|
||||
async def acall_model(state: ChainState, config: RunnableConfig):
|
||||
response = await llm.ainvoke(state["messages"], config)
|
||||
response = await model.ainvoke(state["messages"], config)
|
||||
return {"messages": [response]}
|
||||
|
||||
graph_builder = StateGraph(ChainState)
|
||||
|
||||
@@ -53,16 +53,16 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
input_variables=["page_content"], template="{page_content}"
|
||||
)
|
||||
document_variable_name = "context"
|
||||
llm = OpenAI()
|
||||
model = OpenAI()
|
||||
# The prompt here should take as an input variable the
|
||||
# `document_variable_name`
|
||||
prompt = PromptTemplate.from_template("Summarize this content: {context}")
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt)
|
||||
# We now define how to combine these summaries
|
||||
reduce_prompt = PromptTemplate.from_template(
|
||||
"Combine these summaries: {context}"
|
||||
)
|
||||
reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
|
||||
reduce_llm_chain = LLMChain(llm=model, prompt=reduce_prompt)
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
@@ -79,7 +79,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
# which is specifically aimed at collapsing documents BEFORE
|
||||
# the final call.
|
||||
prompt = PromptTemplate.from_template("Collapse this content: {context}")
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt)
|
||||
collapse_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
|
||||
@@ -42,7 +42,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
from langchain_classic.output_parsers.regex import RegexParser
|
||||
|
||||
document_variable_name = "context"
|
||||
llm = OpenAI()
|
||||
model = OpenAI()
|
||||
# The prompt here should take as an input variable the
|
||||
# `document_variable_name`
|
||||
# The actual prompt will need to be a lot more complex, this is just
|
||||
@@ -61,7 +61,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
input_variables=["context"],
|
||||
output_parser=output_parser,
|
||||
)
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt)
|
||||
chain = MapRerankDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
|
||||
@@ -171,11 +171,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
input_variables=["page_content"], template="{page_content}"
|
||||
)
|
||||
document_variable_name = "context"
|
||||
llm = OpenAI()
|
||||
model = OpenAI()
|
||||
# The prompt here should take as an input variable the
|
||||
# `document_variable_name`
|
||||
prompt = PromptTemplate.from_template("Summarize this content: {context}")
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt)
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
@@ -188,7 +188,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
# which is specifically aimed at collapsing documents BEFORE
|
||||
# the final call.
|
||||
prompt = PromptTemplate.from_template("Collapse this content: {context}")
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt)
|
||||
collapse_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
|
||||
@@ -55,11 +55,11 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
input_variables=["page_content"], template="{page_content}"
|
||||
)
|
||||
document_variable_name = "context"
|
||||
llm = OpenAI()
|
||||
model = OpenAI()
|
||||
# The prompt here should take as an input variable the
|
||||
# `document_variable_name`
|
||||
prompt = PromptTemplate.from_template("Summarize this content: {context}")
|
||||
initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
initial_llm_chain = LLMChain(llm=model, prompt=prompt)
|
||||
initial_response_name = "prev_response"
|
||||
# The prompt here should take as an input variable the
|
||||
# `document_variable_name` as well as `initial_response_name`
|
||||
@@ -67,7 +67,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
"Here's your first summary: {prev_response}. "
|
||||
"Now add to it based on the following context: {context}"
|
||||
)
|
||||
refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
|
||||
refine_llm_chain = LLMChain(llm=model, prompt=prompt_refine)
|
||||
chain = RefineDocumentsChain(
|
||||
initial_llm_chain=initial_llm_chain,
|
||||
refine_llm_chain=refine_llm_chain,
|
||||
|
||||
@@ -68,8 +68,8 @@ def create_stuff_documents_chain(
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("system", "What are everyone's favorite colors:\n\n{context}")]
|
||||
)
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo")
|
||||
chain = create_stuff_documents_chain(llm, prompt)
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo")
|
||||
chain = create_stuff_documents_chain(model, prompt)
|
||||
|
||||
docs = [
|
||||
Document(page_content="Jesse loves red but not yellow"),
|
||||
@@ -132,11 +132,11 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
input_variables=["page_content"], template="{page_content}"
|
||||
)
|
||||
document_variable_name = "context"
|
||||
llm = OpenAI()
|
||||
model = OpenAI()
|
||||
# The prompt here should take as an input variable the
|
||||
# `document_variable_name`
|
||||
prompt = PromptTemplate.from_template("Summarize this content: {context}")
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt)
|
||||
chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_prompt=document_prompt,
|
||||
|
||||
@@ -58,7 +58,7 @@ class ConstitutionalChain(Chain):
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
||||
model = ChatOpenAI(model="gpt-4o-mini")
|
||||
|
||||
class Critique(TypedDict):
|
||||
"""Generate a critique, if needed."""
|
||||
@@ -86,9 +86,9 @@ class ConstitutionalChain(Chain):
|
||||
"Revision Request: {revision_request}"
|
||||
)
|
||||
|
||||
chain = llm | StrOutputParser()
|
||||
critique_chain = critique_prompt | llm.with_structured_output(Critique)
|
||||
revision_chain = revision_prompt | llm | StrOutputParser()
|
||||
chain = model | StrOutputParser()
|
||||
critique_chain = critique_prompt | model.with_structured_output(Critique)
|
||||
revision_chain = revision_prompt | model | StrOutputParser()
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
@@ -170,16 +170,16 @@ class ConstitutionalChain(Chain):
|
||||
from langchain_classic.chains.constitutional_ai.models \
|
||||
import ConstitutionalPrinciple
|
||||
|
||||
llm = OpenAI()
|
||||
model = OpenAI()
|
||||
|
||||
qa_prompt = PromptTemplate(
|
||||
template="Q: {question} A:",
|
||||
input_variables=["question"],
|
||||
)
|
||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||
qa_chain = LLMChain(llm=model, prompt=qa_prompt)
|
||||
|
||||
constitutional_chain = ConstitutionalChain.from_llm(
|
||||
llm=llm,
|
||||
llm=model,
|
||||
chain=qa_chain,
|
||||
constitutional_principles=[
|
||||
ConstitutionalPrinciple(
|
||||
|
||||
@@ -47,9 +47,9 @@ class ConversationChain(LLMChain):
|
||||
return store[session_id]
|
||||
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0125")
|
||||
|
||||
chain = RunnableWithMessageHistory(llm, get_session_history)
|
||||
chain = RunnableWithMessageHistory(model, get_session_history)
|
||||
chain.invoke(
|
||||
"Hi I'm Bob.",
|
||||
config={"configurable": {"session_id": "1"}},
|
||||
@@ -85,9 +85,9 @@ class ConversationChain(LLMChain):
|
||||
return store[session_id]
|
||||
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0125")
|
||||
|
||||
chain = RunnableWithMessageHistory(llm, get_session_history)
|
||||
chain = RunnableWithMessageHistory(model, get_session_history)
|
||||
chain.invoke(
|
||||
"Hi I'm Bob.",
|
||||
config={"configurable": {"session_id": "1"}},
|
||||
|
||||
@@ -283,7 +283,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
|
||||
retriever = ... # Your retriever
|
||||
|
||||
llm = ChatOpenAI()
|
||||
model = ChatOpenAI()
|
||||
|
||||
# Contextualize question
|
||||
contextualize_q_system_prompt = (
|
||||
@@ -301,7 +301,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
]
|
||||
)
|
||||
history_aware_retriever = create_history_aware_retriever(
|
||||
llm, retriever, contextualize_q_prompt
|
||||
model, retriever, contextualize_q_prompt
|
||||
)
|
||||
|
||||
# Answer question
|
||||
@@ -324,7 +324,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
# Below we use create_stuff_documents_chain to feed all retrieved context
|
||||
# into the LLM. Note that we can also use StuffDocumentsChain and other
|
||||
# instances of BaseCombineDocumentsChain.
|
||||
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
|
||||
question_answer_chain = create_stuff_documents_chain(model, qa_prompt)
|
||||
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
|
||||
|
||||
# Usage:
|
||||
@@ -371,8 +371,8 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
"Follow up question: {question}"
|
||||
)
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
llm = OpenAI()
|
||||
question_generator_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
model = OpenAI()
|
||||
question_generator_chain = LLMChain(llm=model, prompt=prompt)
|
||||
chain = ConversationalRetrievalChain(
|
||||
combine_docs_chain=combine_docs_chain,
|
||||
retriever=retriever,
|
||||
|
||||
@@ -38,10 +38,10 @@ def create_history_aware_retriever(
|
||||
from langchain_classic import hub
|
||||
|
||||
rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase")
|
||||
llm = ChatOpenAI()
|
||||
model = ChatOpenAI()
|
||||
retriever = ...
|
||||
chat_retriever_chain = create_history_aware_retriever(
|
||||
llm, retriever, rephrase_prompt
|
||||
model, retriever, rephrase_prompt
|
||||
)
|
||||
|
||||
chain.invoke({"input": "...", "chat_history": })
|
||||
|
||||
@@ -55,8 +55,8 @@ class LLMChain(Chain):
|
||||
|
||||
prompt_template = "Tell me a {adjective} joke"
|
||||
prompt = PromptTemplate(input_variables=["adjective"], template=prompt_template)
|
||||
llm = OpenAI()
|
||||
chain = prompt | llm | StrOutputParser()
|
||||
model = OpenAI()
|
||||
chain = prompt | model | StrOutputParser()
|
||||
|
||||
chain.invoke("your adjective here")
|
||||
```
|
||||
@@ -69,7 +69,7 @@ class LLMChain(Chain):
|
||||
|
||||
prompt_template = "Tell me a {adjective} joke"
|
||||
prompt = PromptTemplate(input_variables=["adjective"], template=prompt_template)
|
||||
llm = LLMChain(llm=OpenAI(), prompt=prompt)
|
||||
model = LLMChain(llm=OpenAI(), prompt=prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -80,8 +80,8 @@ class LLMCheckerChain(Chain):
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_classic.chains import LLMCheckerChain
|
||||
|
||||
llm = OpenAI(temperature=0.7)
|
||||
checker_chain = LLMCheckerChain.from_llm(llm)
|
||||
model = OpenAI(temperature=0.7)
|
||||
checker_chain = LLMCheckerChain.from_llm(model)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -84,9 +84,9 @@ class LLMMathChain(Chain):
|
||||
)
|
||||
)
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
||||
tools = [calculator]
|
||||
llm_with_tools = llm.bind_tools(tools, tool_choice="any")
|
||||
model_with_tools = model.bind_tools(tools, tool_choice="any")
|
||||
|
||||
class ChainState(TypedDict):
|
||||
\"\"\"LangGraph state.\"\"\"
|
||||
@@ -95,11 +95,11 @@ class LLMMathChain(Chain):
|
||||
|
||||
async def acall_chain(state: ChainState, config: RunnableConfig):
|
||||
last_message = state["messages"][-1]
|
||||
response = await llm_with_tools.ainvoke(state["messages"], config)
|
||||
response = await model_with_tools.ainvoke(state["messages"], config)
|
||||
return {"messages": [response]}
|
||||
|
||||
async def acall_model(state: ChainState, config: RunnableConfig):
|
||||
response = await llm.ainvoke(state["messages"], config)
|
||||
response = await model.ainvoke(state["messages"], config)
|
||||
return {"messages": [response]}
|
||||
|
||||
graph_builder = StateGraph(ChainState)
|
||||
|
||||
@@ -83,8 +83,8 @@ class LLMSummarizationCheckerChain(Chain):
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_classic.chains import LLMSummarizationCheckerChain
|
||||
|
||||
llm = OpenAI(temperature=0.0)
|
||||
checker_chain = LLMSummarizationCheckerChain.from_llm(llm)
|
||||
model = OpenAI(temperature=0.0)
|
||||
checker_chain = LLMSummarizationCheckerChain.from_llm(model)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -84,8 +84,8 @@ class NatBotChain(Chain):
|
||||
"""Load with default LLMChain."""
|
||||
msg = (
|
||||
"This method is no longer implemented. Please use from_llm."
|
||||
"llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)"
|
||||
"For example, NatBotChain.from_llm(llm, objective)"
|
||||
"model = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)"
|
||||
"For example, NatBotChain.from_llm(model, objective)"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ def create_openai_fn_chain(
|
||||
fav_food: str | None = Field(None, description="The dog's favorite food")
|
||||
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4", temperature=0)
|
||||
model = ChatOpenAI(model="gpt-4", temperature=0)
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a world class algorithm for recording entities."),
|
||||
@@ -115,7 +115,7 @@ def create_openai_fn_chain(
|
||||
("human", "Tip: Make sure to answer in the correct format"),
|
||||
]
|
||||
)
|
||||
chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt)
|
||||
chain = create_openai_fn_chain([RecordPerson, RecordDog], model, prompt)
|
||||
chain.run("Harry was a chubby brown beagle who loved chicken")
|
||||
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
||||
|
||||
@@ -191,7 +191,7 @@ def create_structured_output_chain(
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: str | None = Field(None, description="The dog's favorite food")
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a world class algorithm for extracting information in structured formats."),
|
||||
@@ -199,7 +199,7 @@ def create_structured_output_chain(
|
||||
("human", "Tip: Make sure to answer in the correct format"),
|
||||
]
|
||||
)
|
||||
chain = create_structured_output_chain(Dog, llm, prompt)
|
||||
chain = create_structured_output_chain(Dog, model, prompt)
|
||||
chain.run("Harry was a chubby brown beagle who loved chicken")
|
||||
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
||||
|
||||
|
||||
@@ -83,12 +83,12 @@ def create_citation_fuzzy_match_runnable(llm: BaseChatModel) -> Runnable:
|
||||
from langchain_classic.chains import create_citation_fuzzy_match_runnable
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
||||
model = ChatOpenAI(model="gpt-4o-mini")
|
||||
|
||||
context = "Alice has blue eyes. Bob has brown eyes. Charlie has green eyes."
|
||||
question = "What color are Bob's eyes?"
|
||||
|
||||
chain = create_citation_fuzzy_match_runnable(llm)
|
||||
chain = create_citation_fuzzy_match_runnable(model)
|
||||
chain.invoke({"question": question, "context": context})
|
||||
```
|
||||
|
||||
|
||||
@@ -73,8 +73,8 @@ Passage:
|
||||
# to see an up to date list of which models support
|
||||
# with_structured_output.
|
||||
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
|
||||
structured_llm = model.with_structured_output(Joke)
|
||||
structured_llm.invoke("Tell me a joke about cats.
|
||||
structured_model = model.with_structured_output(Joke)
|
||||
structured_model.invoke("Tell me a joke about cats.
|
||||
Make sure to call the Joke function.")
|
||||
"""
|
||||
),
|
||||
@@ -143,8 +143,8 @@ def create_extraction_chain(
|
||||
# to see an up to date list of which models support
|
||||
# with_structured_output.
|
||||
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
|
||||
structured_llm = model.with_structured_output(Joke)
|
||||
structured_llm.invoke("Tell me a joke about cats.
|
||||
structured_model = model.with_structured_output(Joke)
|
||||
structured_model.invoke("Tell me a joke about cats.
|
||||
Make sure to call the Joke function.")
|
||||
"""
|
||||
),
|
||||
@@ -155,10 +155,10 @@ def create_extraction_chain_pydantic(
|
||||
prompt: BasePromptTemplate | None = None,
|
||||
verbose: bool = False, # noqa: FBT001,FBT002
|
||||
) -> Chain:
|
||||
"""Creates a chain that extracts information from a passage using pydantic schema.
|
||||
"""Creates a chain that extracts information from a passage using Pydantic schema.
|
||||
|
||||
Args:
|
||||
pydantic_schema: The pydantic schema of the entities to extract.
|
||||
pydantic_schema: The Pydantic schema of the entities to extract.
|
||||
llm: The language model to use.
|
||||
prompt: The prompt to use for extraction.
|
||||
verbose: Whether to run in verbose mode. In verbose mode, some intermediate
|
||||
|
||||
@@ -330,7 +330,7 @@ def get_openapi_chain(
|
||||
prompt = ChatPromptTemplate.from_template(
|
||||
"Use the provided APIs to respond to this user query:\\n\\n{query}"
|
||||
)
|
||||
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0).bind_tools(tools)
|
||||
model = ChatOpenAI(model="gpt-4o-mini", temperature=0).bind_tools(tools)
|
||||
|
||||
def _execute_tool(message) -> Any:
|
||||
if tool_calls := message.tool_calls:
|
||||
@@ -341,7 +341,7 @@ def get_openapi_chain(
|
||||
else:
|
||||
return message.content
|
||||
|
||||
chain = prompt | llm | _execute_tool
|
||||
chain = prompt | model | _execute_tool
|
||||
```
|
||||
|
||||
```python
|
||||
@@ -394,7 +394,7 @@ def get_openapi_chain(
|
||||
msg = (
|
||||
"Must provide an LLM for this chain.For example,\n"
|
||||
"from langchain_openai import ChatOpenAI\n"
|
||||
"llm = ChatOpenAI()\n"
|
||||
"model = ChatOpenAI()\n"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
prompt = prompt or ChatPromptTemplate.from_template(
|
||||
|
||||
@@ -77,8 +77,8 @@ def create_tagging_chain(
|
||||
# to see an up to date list of which models support
|
||||
# with_structured_output.
|
||||
model = ChatAnthropic(model="claude-3-haiku-20240307", temperature=0)
|
||||
structured_llm = model.with_structured_output(Joke)
|
||||
structured_llm.invoke(
|
||||
structured_model = model.with_structured_output(Joke)
|
||||
structured_model.invoke(
|
||||
"Why did the cat cross the road? To get to the other "
|
||||
"side... and then lay down in the middle of it!"
|
||||
)
|
||||
@@ -93,7 +93,7 @@ def create_tagging_chain(
|
||||
kwargs: Additional keyword arguments to pass to the chain.
|
||||
|
||||
Returns:
|
||||
Chain (LLMChain) that can be used to extract information from a passage.
|
||||
Chain (`LLMChain`) that can be used to extract information from a passage.
|
||||
|
||||
"""
|
||||
function = _get_tagging_function(schema)
|
||||
@@ -130,10 +130,10 @@ def create_tagging_chain_pydantic(
|
||||
prompt: ChatPromptTemplate | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Chain:
|
||||
"""Create tagging chain from pydantic schema.
|
||||
"""Create tagging chain from Pydantic schema.
|
||||
|
||||
Create a chain that extracts information from a passage
|
||||
based on a pydantic schema.
|
||||
based on a Pydantic schema.
|
||||
|
||||
This function is deprecated. Please use `with_structured_output` instead.
|
||||
See example usage below:
|
||||
@@ -153,8 +153,8 @@ def create_tagging_chain_pydantic(
|
||||
# to see an up to date list of which models support
|
||||
# with_structured_output.
|
||||
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
|
||||
structured_llm = model.with_structured_output(Joke)
|
||||
structured_llm.invoke(
|
||||
structured_model = model.with_structured_output(Joke)
|
||||
structured_model.invoke(
|
||||
"Why did the cat cross the road? To get to the other "
|
||||
"side... and then lay down in the middle of it!"
|
||||
)
|
||||
@@ -163,13 +163,13 @@ def create_tagging_chain_pydantic(
|
||||
Read more here: https://python.langchain.com/docs/how_to/structured_output/
|
||||
|
||||
Args:
|
||||
pydantic_schema: The pydantic schema of the entities to extract.
|
||||
pydantic_schema: The Pydantic schema of the entities to extract.
|
||||
llm: The language model to use.
|
||||
prompt: The prompt template to use for the chain.
|
||||
kwargs: Additional keyword arguments to pass to the chain.
|
||||
|
||||
Returns:
|
||||
Chain (LLMChain) that can be used to extract information from a passage.
|
||||
Chain (`LLMChain`) that can be used to extract information from a passage.
|
||||
|
||||
"""
|
||||
if hasattr(pydantic_schema, "model_json_schema"):
|
||||
|
||||
@@ -42,8 +42,8 @@ If a property is not present and is not required in the function parameters, do
|
||||
# to see an up to date list of which models support
|
||||
# with_structured_output.
|
||||
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
|
||||
structured_llm = model.with_structured_output(Joke)
|
||||
structured_llm.invoke("Tell me a joke about cats.
|
||||
structured_model = model.with_structured_output(Joke)
|
||||
structured_model.invoke("Tell me a joke about cats.
|
||||
Make sure to call the Joke function.")
|
||||
"""
|
||||
),
|
||||
|
||||
@@ -48,7 +48,7 @@ def is_llm(llm: BaseLanguageModel) -> bool:
|
||||
llm: Language model to check.
|
||||
|
||||
Returns:
|
||||
True if the language model is a BaseLLM model, False otherwise.
|
||||
`True` if the language model is a BaseLLM model, `False` otherwise.
|
||||
"""
|
||||
return isinstance(llm, BaseLLM)
|
||||
|
||||
@@ -60,6 +60,6 @@ def is_chat_model(llm: BaseLanguageModel) -> bool:
|
||||
llm: Language model to check.
|
||||
|
||||
Returns:
|
||||
True if the language model is a BaseChatModel model, False otherwise.
|
||||
`True` if the language model is a BaseChatModel model, `False` otherwise.
|
||||
"""
|
||||
return isinstance(llm, BaseChatModel)
|
||||
|
||||
@@ -34,7 +34,7 @@ class QAGenerationChain(Chain):
|
||||
- Supports async and streaming;
|
||||
- Surfaces prompt and text splitter for easier customization;
|
||||
- Use of JsonOutputParser supports JSONPatch operations in streaming mode,
|
||||
as well as robustness to markdown.
|
||||
as well as robustness to markdown.
|
||||
|
||||
```python
|
||||
from langchain_classic.chains.qa_generation.prompt import (
|
||||
@@ -52,14 +52,14 @@ class QAGenerationChain(Chain):
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
llm = ChatOpenAI()
|
||||
model = ChatOpenAI()
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=500)
|
||||
split_text = RunnableLambda(lambda x: text_splitter.create_documents([x]))
|
||||
|
||||
chain = RunnableParallel(
|
||||
text=RunnablePassthrough(),
|
||||
questions=(
|
||||
split_text | RunnableEach(bound=prompt | llm | JsonOutputParser())
|
||||
split_text | RunnableEach(bound=prompt | model | JsonOutputParser())
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
@@ -112,7 +112,7 @@ class QueryTransformer(Transformer):
|
||||
args: The arguments passed to the function.
|
||||
|
||||
Returns:
|
||||
FilterDirective: The filter directive.
|
||||
The filter directive.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function is a comparator and the first arg is not in the
|
||||
|
||||
@@ -46,9 +46,11 @@ def create_retrieval_chain(
|
||||
from langchain_classic import hub
|
||||
|
||||
retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
|
||||
llm = ChatOpenAI()
|
||||
model = ChatOpenAI()
|
||||
retriever = ...
|
||||
combine_docs_chain = create_stuff_documents_chain(llm, retrieval_qa_chat_prompt)
|
||||
combine_docs_chain = create_stuff_documents_chain(
|
||||
model, retrieval_qa_chat_prompt
|
||||
)
|
||||
retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)
|
||||
|
||||
retrieval_chain.invoke({"input": "..."})
|
||||
|
||||
@@ -237,7 +237,7 @@ class RetrievalQA(BaseRetrievalQA):
|
||||
|
||||
|
||||
retriever = ... # Your retriever
|
||||
llm = ChatOpenAI()
|
||||
model = ChatOpenAI()
|
||||
|
||||
system_prompt = (
|
||||
"Use the given context to answer the question. "
|
||||
@@ -251,7 +251,7 @@ class RetrievalQA(BaseRetrievalQA):
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
question_answer_chain = create_stuff_documents_chain(llm, prompt)
|
||||
question_answer_chain = create_stuff_documents_chain(model, prompt)
|
||||
chain = create_retrieval_chain(retriever, question_answer_chain)
|
||||
|
||||
chain.invoke({"input": query})
|
||||
|
||||
@@ -48,7 +48,7 @@ class LLMRouterChain(RouterChain):
|
||||
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
||||
model = ChatOpenAI(model="gpt-4o-mini")
|
||||
|
||||
prompt_1 = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
@@ -63,8 +63,8 @@ class LLMRouterChain(RouterChain):
|
||||
]
|
||||
)
|
||||
|
||||
chain_1 = prompt_1 | llm | StrOutputParser()
|
||||
chain_2 = prompt_2 | llm | StrOutputParser()
|
||||
chain_1 = prompt_1 | model | StrOutputParser()
|
||||
chain_2 = prompt_2 | model | StrOutputParser()
|
||||
|
||||
route_system = "Route the user's query to either the animal "
|
||||
"or vegetable expert."
|
||||
@@ -83,7 +83,7 @@ class LLMRouterChain(RouterChain):
|
||||
|
||||
route_chain = (
|
||||
route_prompt
|
||||
| llm.with_structured_output(RouteQuery)
|
||||
| model.with_structured_output(RouteQuery)
|
||||
| itemgetter("destination")
|
||||
)
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ class MultiPromptChain(MultiRouteChain):
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini")
|
||||
model = ChatOpenAI(model="gpt-4o-mini")
|
||||
|
||||
# Define the prompts we will route to
|
||||
prompt_1 = ChatPromptTemplate.from_messages(
|
||||
@@ -68,8 +68,8 @@ class MultiPromptChain(MultiRouteChain):
|
||||
# Construct the chains we will route to. These format the input query
|
||||
# into the respective prompt, run it through a chat model, and cast
|
||||
# the result to a string.
|
||||
chain_1 = prompt_1 | llm | StrOutputParser()
|
||||
chain_2 = prompt_2 | llm | StrOutputParser()
|
||||
chain_1 = prompt_1 | model | StrOutputParser()
|
||||
chain_2 = prompt_2 | model | StrOutputParser()
|
||||
|
||||
|
||||
# Next: define the chain that selects which branch to route to.
|
||||
@@ -92,7 +92,7 @@ class MultiPromptChain(MultiRouteChain):
|
||||
destination: Literal["animal", "vegetable"]
|
||||
|
||||
|
||||
route_chain = route_prompt | llm.with_structured_output(RouteQuery)
|
||||
route_chain = route_prompt | model.with_structured_output(RouteQuery)
|
||||
|
||||
|
||||
# For LangGraph, we will define the state of the graph to hold the query,
|
||||
|
||||
@@ -117,7 +117,7 @@ class MultiRetrievalQAChain(MultiRouteChain):
|
||||
"default LLMs on behalf of users."
|
||||
"You can provide a conversation LLM like so:\n"
|
||||
"from langchain_openai import ChatOpenAI\n"
|
||||
"llm = ChatOpenAI()"
|
||||
"model = ChatOpenAI()"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
_default_chain = ConversationChain(
|
||||
|
||||
@@ -76,8 +76,8 @@ def create_sql_query_chain(
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
|
||||
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
||||
chain = create_sql_query_chain(llm, db)
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
||||
chain = create_sql_query_chain(model, db)
|
||||
response = chain.invoke({"question": "How many employees are there"})
|
||||
```
|
||||
|
||||
|
||||
@@ -57,8 +57,8 @@ from pydantic import BaseModel
|
||||
# to see an up to date list of which models support
|
||||
# with_structured_output.
|
||||
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
|
||||
structured_llm = model.with_structured_output(Joke)
|
||||
structured_llm.invoke("Tell me a joke about cats.
|
||||
structured_model = model.with_structured_output(Joke)
|
||||
structured_model.invoke("Tell me a joke about cats.
|
||||
Make sure to call the Joke function.")
|
||||
"""
|
||||
),
|
||||
@@ -127,9 +127,9 @@ def create_openai_fn_runnable(
|
||||
fav_food: str | None = Field(None, description="The dog's favorite food")
|
||||
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4", temperature=0)
|
||||
structured_llm = create_openai_fn_runnable([RecordPerson, RecordDog], llm)
|
||||
structured_llm.invoke("Harry was a chubby brown beagle who loved chicken)
|
||||
model = ChatOpenAI(model="gpt-4", temperature=0)
|
||||
structured_model = create_openai_fn_runnable([RecordPerson, RecordDog], model)
|
||||
structured_model.invoke("Harry was a chubby brown beagle who loved chicken)
|
||||
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
||||
|
||||
```
|
||||
@@ -176,8 +176,8 @@ def create_openai_fn_runnable(
|
||||
# to see an up to date list of which models support
|
||||
# with_structured_output.
|
||||
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
|
||||
structured_llm = model.with_structured_output(Joke)
|
||||
structured_llm.invoke("Tell me a joke about cats.
|
||||
structured_model = model.with_structured_output(Joke)
|
||||
structured_model.invoke("Tell me a joke about cats.
|
||||
Make sure to call the Joke function.")
|
||||
"""
|
||||
),
|
||||
@@ -250,21 +250,21 @@ def create_structured_output_runnable(
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: str | None = Field(None, description="The dog's favorite food")
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an extraction algorithm. Please extract every possible instance"),
|
||||
('human', '{input}')
|
||||
]
|
||||
)
|
||||
structured_llm = create_structured_output_runnable(
|
||||
structured_model = create_structured_output_runnable(
|
||||
RecordDog,
|
||||
llm,
|
||||
model,
|
||||
mode="openai-tools",
|
||||
enforce_function_usage=True,
|
||||
return_single=True
|
||||
)
|
||||
structured_llm.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
||||
structured_model.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
||||
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
||||
```
|
||||
|
||||
@@ -303,15 +303,15 @@ def create_structured_output_runnable(
|
||||
}
|
||||
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_llm = create_structured_output_runnable(
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_model = create_structured_output_runnable(
|
||||
dog_schema,
|
||||
llm,
|
||||
model,
|
||||
mode="openai-tools",
|
||||
enforce_function_usage=True,
|
||||
return_single=True
|
||||
)
|
||||
structured_llm.invoke("Harry was a chubby brown beagle who loved chicken")
|
||||
structured_model.invoke("Harry was a chubby brown beagle who loved chicken")
|
||||
# -> {'name': 'Harry', 'color': 'brown', 'fav_food': 'chicken'}
|
||||
```
|
||||
|
||||
@@ -330,9 +330,9 @@ def create_structured_output_runnable(
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: str | None = Field(None, description="The dog's favorite food")
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-functions")
|
||||
structured_llm.invoke("Harry was a chubby brown beagle who loved chicken")
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_model = create_structured_output_runnable(Dog, model, mode="openai-functions")
|
||||
structured_model.invoke("Harry was a chubby brown beagle who loved chicken")
|
||||
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
||||
```
|
||||
|
||||
@@ -352,13 +352,13 @@ def create_structured_output_runnable(
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: str | None = Field(None, description="The dog's favorite food")
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-functions")
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_model = create_structured_output_runnable(Dog, model, mode="openai-functions")
|
||||
system = '''Extract information about any dogs mentioned in the user input.'''
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("system", system), ("human", "{input}"),]
|
||||
)
|
||||
chain = prompt | structured_llm
|
||||
chain = prompt | structured_model
|
||||
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
||||
# -> Dog(name="Harry", color="brown", fav_food="chicken")
|
||||
```
|
||||
@@ -379,8 +379,8 @@ def create_structured_output_runnable(
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: str | None = Field(None, description="The dog's favorite food")
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_llm = create_structured_output_runnable(Dog, llm, mode="openai-json")
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_model = create_structured_output_runnable(Dog, model, mode="openai-json")
|
||||
system = '''You are a world class assistant for extracting information in structured JSON formats. \
|
||||
|
||||
Extract a valid JSON blob from the user input that matches the following JSON Schema:
|
||||
@@ -389,7 +389,7 @@ def create_structured_output_runnable(
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("system", system), ("human", "{input}"),]
|
||||
)
|
||||
chain = prompt | structured_llm
|
||||
chain = prompt | structured_model
|
||||
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
||||
|
||||
```
|
||||
|
||||
@@ -113,10 +113,10 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
|
||||
\"\"\"Very helpful answers to geography questions.\"\"\"
|
||||
return f"{country}? IDK - We may never know {question}."
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
||||
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
||||
agent = initialize_agent(
|
||||
tools=[geography_answers],
|
||||
llm=llm,
|
||||
llm=model,
|
||||
agent=AgentType.OPENAI_FUNCTIONS,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
@@ -125,7 +125,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
|
||||
response = agent(question)
|
||||
|
||||
eval_chain = TrajectoryEvalChain.from_llm(
|
||||
llm=llm, agent_tools=[geography_answers], return_reasoning=True
|
||||
llm=model, agent_tools=[geography_answers], return_reasoning=True
|
||||
)
|
||||
|
||||
result = eval_chain.evaluate_agent_trajectory(
|
||||
@@ -165,7 +165,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
|
||||
"""Get the description of the agent tools.
|
||||
|
||||
Returns:
|
||||
str: The description of the agent tools.
|
||||
The description of the agent tools.
|
||||
"""
|
||||
if self.agent_tools is None:
|
||||
return ""
|
||||
@@ -184,10 +184,10 @@ Description: {tool.description}"""
|
||||
"""Get the agent trajectory as a formatted string.
|
||||
|
||||
Args:
|
||||
steps (Union[str, List[Tuple[AgentAction, str]]]): The agent trajectory.
|
||||
steps: The agent trajectory.
|
||||
|
||||
Returns:
|
||||
str: The formatted agent trajectory.
|
||||
The formatted agent trajectory.
|
||||
"""
|
||||
if isinstance(steps, str):
|
||||
return steps
|
||||
@@ -240,7 +240,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
TrajectoryEvalChain: The TrajectoryEvalChain object.
|
||||
The `TrajectoryEvalChain` object.
|
||||
"""
|
||||
if not isinstance(llm, BaseChatModel):
|
||||
msg = "Only chat models supported by the current trajectory eval"
|
||||
@@ -259,7 +259,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
"""Get the input keys for the chain.
|
||||
|
||||
Returns:
|
||||
List[str]: The input keys.
|
||||
The input keys.
|
||||
"""
|
||||
return ["question", "agent_trajectory", "answer", "reference"]
|
||||
|
||||
@@ -268,7 +268,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
"""Get the output keys for the chain.
|
||||
|
||||
Returns:
|
||||
List[str]: The output keys.
|
||||
The output keys.
|
||||
"""
|
||||
return ["score", "reasoning"]
|
||||
|
||||
@@ -289,7 +289,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
run_manager: The callback manager for the chain run.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The output values of the chain.
|
||||
The output values of the chain.
|
||||
"""
|
||||
chain_input = {**inputs}
|
||||
if self.agent_tools:
|
||||
@@ -313,7 +313,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
run_manager: The callback manager for the chain run.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The output values of the chain.
|
||||
The output values of the chain.
|
||||
"""
|
||||
chain_input = {**inputs}
|
||||
if self.agent_tools:
|
||||
|
||||
@@ -165,10 +165,10 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
|
||||
Example:
|
||||
>>> from langchain_community.chat_models import ChatOpenAI
|
||||
>>> from langchain_classic.evaluation.comparison import PairwiseStringEvalChain
|
||||
>>> llm = ChatOpenAI(
|
||||
>>> model = ChatOpenAI(
|
||||
... temperature=0, model_name="gpt-4", model_kwargs={"random_seed": 42}
|
||||
... )
|
||||
>>> chain = PairwiseStringEvalChain.from_llm(llm=llm)
|
||||
>>> chain = PairwiseStringEvalChain.from_llm(llm=model)
|
||||
>>> result = chain.evaluate_string_pairs(
|
||||
... input = "What is the chemical formula for water?",
|
||||
... prediction = "H2O",
|
||||
@@ -207,7 +207,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
|
||||
"""Return whether the chain requires a reference.
|
||||
|
||||
Returns:
|
||||
True if the chain requires a reference, False otherwise.
|
||||
`True` if the chain requires a reference, `False` otherwise.
|
||||
|
||||
"""
|
||||
return False
|
||||
@@ -217,7 +217,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
|
||||
"""Return whether the chain requires an input.
|
||||
|
||||
Returns:
|
||||
bool: True if the chain requires an input, False otherwise.
|
||||
`True` if the chain requires an input, `False` otherwise.
|
||||
|
||||
"""
|
||||
return True
|
||||
@@ -227,7 +227,7 @@ class PairwiseStringEvalChain(PairwiseStringEvaluator, LLMEvalChain, LLMChain):
|
||||
"""Return the warning to show when reference is ignored.
|
||||
|
||||
Returns:
|
||||
str: The warning to show when reference is ignored.
|
||||
The warning to show when reference is ignored.
|
||||
|
||||
"""
|
||||
return (
|
||||
@@ -425,7 +425,7 @@ class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
|
||||
"""Return whether the chain requires a reference.
|
||||
|
||||
Returns:
|
||||
bool: True if the chain requires a reference, False otherwise.
|
||||
`True` if the chain requires a reference, `False` otherwise.
|
||||
|
||||
"""
|
||||
return True
|
||||
@@ -442,18 +442,18 @@ class LabeledPairwiseStringEvalChain(PairwiseStringEvalChain):
|
||||
"""Initialize the LabeledPairwiseStringEvalChain from an LLM.
|
||||
|
||||
Args:
|
||||
llm (BaseLanguageModel): The LLM to use.
|
||||
prompt (PromptTemplate, optional): The prompt to use.
|
||||
criteria (Union[CRITERIA_TYPE, str], optional): The criteria to use.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
llm: The LLM to use.
|
||||
prompt: The prompt to use.
|
||||
criteria: The criteria to use.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
LabeledPairwiseStringEvalChain: The initialized LabeledPairwiseStringEvalChain.
|
||||
The initialized `LabeledPairwiseStringEvalChain`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input variables are not as expected.
|
||||
|
||||
""" # noqa: E501
|
||||
"""
|
||||
expected_input_vars = {
|
||||
"prediction",
|
||||
"prediction_b",
|
||||
|
||||
@@ -15,9 +15,9 @@ Using a predefined criterion:
|
||||
>>> from langchain_community.llms import OpenAI
|
||||
>>> from langchain_classic.evaluation.criteria import CriteriaEvalChain
|
||||
|
||||
>>> llm = OpenAI()
|
||||
>>> model = OpenAI()
|
||||
>>> criteria = "conciseness"
|
||||
>>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
|
||||
>>> chain = CriteriaEvalChain.from_llm(llm=model, criteria=criteria)
|
||||
>>> chain.evaluate_strings(
|
||||
prediction="The answer is 42.",
|
||||
reference="42",
|
||||
@@ -29,7 +29,7 @@ Using a custom criterion:
|
||||
>>> from langchain_community.llms import OpenAI
|
||||
>>> from langchain_classic.evaluation.criteria import LabeledCriteriaEvalChain
|
||||
|
||||
>>> llm = OpenAI()
|
||||
>>> model = OpenAI()
|
||||
>>> criteria = {
|
||||
"hallucination": (
|
||||
"Does this submission contain information"
|
||||
@@ -37,7 +37,7 @@ Using a custom criterion:
|
||||
),
|
||||
}
|
||||
>>> chain = LabeledCriteriaEvalChain.from_llm(
|
||||
llm=llm,
|
||||
llm=model,
|
||||
criteria=criteria,
|
||||
)
|
||||
>>> chain.evaluate_strings(
|
||||
|
||||
@@ -190,9 +190,9 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
|
||||
--------
|
||||
>>> from langchain_anthropic import ChatAnthropic
|
||||
>>> from langchain_classic.evaluation.criteria import CriteriaEvalChain
|
||||
>>> llm = ChatAnthropic(temperature=0)
|
||||
>>> model = ChatAnthropic(temperature=0)
|
||||
>>> criteria = {"my-custom-criterion": "Is the submission the most amazing ever?"}
|
||||
>>> evaluator = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
|
||||
>>> evaluator = CriteriaEvalChain.from_llm(llm=model, criteria=criteria)
|
||||
>>> evaluator.evaluate_strings(
|
||||
... prediction="Imagine an ice cream flavor for the color aquamarine",
|
||||
... input="Tell me an idea",
|
||||
@@ -205,10 +205,10 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
|
||||
|
||||
>>> from langchain_openai import ChatOpenAI
|
||||
>>> from langchain_classic.evaluation.criteria import LabeledCriteriaEvalChain
|
||||
>>> llm = ChatOpenAI(model="gpt-4", temperature=0)
|
||||
>>> model = ChatOpenAI(model="gpt-4", temperature=0)
|
||||
>>> criteria = "correctness"
|
||||
>>> evaluator = LabeledCriteriaEvalChain.from_llm(
|
||||
... llm=llm,
|
||||
... llm=model,
|
||||
... criteria=criteria,
|
||||
... )
|
||||
>>> evaluator.evaluate_strings(
|
||||
@@ -347,7 +347,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
|
||||
--------
|
||||
>>> from langchain_openai import OpenAI
|
||||
>>> from langchain_classic.evaluation.criteria import LabeledCriteriaEvalChain
|
||||
>>> llm = OpenAI()
|
||||
>>> model = OpenAI()
|
||||
>>> criteria = {
|
||||
"hallucination": (
|
||||
"Does this submission contain information"
|
||||
@@ -355,7 +355,7 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
|
||||
),
|
||||
}
|
||||
>>> chain = LabeledCriteriaEvalChain.from_llm(
|
||||
llm=llm,
|
||||
llm=model,
|
||||
criteria=criteria,
|
||||
)
|
||||
"""
|
||||
@@ -433,9 +433,9 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
|
||||
Examples:
|
||||
>>> from langchain_openai import OpenAI
|
||||
>>> from langchain_classic.evaluation.criteria import CriteriaEvalChain
|
||||
>>> llm = OpenAI()
|
||||
>>> model = OpenAI()
|
||||
>>> criteria = "conciseness"
|
||||
>>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
|
||||
>>> chain = CriteriaEvalChain.from_llm(llm=model, criteria=criteria)
|
||||
>>> chain.evaluate_strings(
|
||||
prediction="The answer is 42.",
|
||||
reference="42",
|
||||
@@ -485,9 +485,9 @@ class CriteriaEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
|
||||
Examples:
|
||||
>>> from langchain_openai import OpenAI
|
||||
>>> from langchain_classic.evaluation.criteria import CriteriaEvalChain
|
||||
>>> llm = OpenAI()
|
||||
>>> model = OpenAI()
|
||||
>>> criteria = "conciseness"
|
||||
>>> chain = CriteriaEvalChain.from_llm(llm=llm, criteria=criteria)
|
||||
>>> chain = CriteriaEvalChain.from_llm(llm=model, criteria=criteria)
|
||||
>>> await chain.aevaluate_strings(
|
||||
prediction="The answer is 42.",
|
||||
reference="42",
|
||||
@@ -569,7 +569,7 @@ class LabeledCriteriaEvalChain(CriteriaEvalChain):
|
||||
--------
|
||||
>>> from langchain_openai import OpenAI
|
||||
>>> from langchain_classic.evaluation.criteria import LabeledCriteriaEvalChain
|
||||
>>> llm = OpenAI()
|
||||
>>> model = OpenAI()
|
||||
>>> criteria = {
|
||||
"hallucination": (
|
||||
"Does this submission contain information"
|
||||
@@ -577,7 +577,7 @@ class LabeledCriteriaEvalChain(CriteriaEvalChain):
|
||||
),
|
||||
}
|
||||
>>> chain = LabeledCriteriaEvalChain.from_llm(
|
||||
llm=llm,
|
||||
llm=model,
|
||||
criteria=criteria,
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -51,7 +51,7 @@ def _embedding_factory() -> Embeddings:
|
||||
"""Create an Embeddings object.
|
||||
|
||||
Returns:
|
||||
Embeddings: The created Embeddings object.
|
||||
The created `Embeddings` object.
|
||||
"""
|
||||
# Here for backwards compatibility.
|
||||
# Generally, we do not want to be seeing imports from langchain community
|
||||
@@ -94,9 +94,8 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
"""Shared functionality for embedding distance evaluators.
|
||||
|
||||
Attributes:
|
||||
embeddings (Embeddings): The embedding objects to vectorize the outputs.
|
||||
distance_metric (EmbeddingDistance): The distance metric to use
|
||||
for comparing the embeddings.
|
||||
embeddings: The embedding objects to vectorize the outputs.
|
||||
distance_metric: The distance metric to use for comparing the embeddings.
|
||||
"""
|
||||
|
||||
embeddings: Embeddings = Field(default_factory=_embedding_factory)
|
||||
@@ -107,10 +106,10 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
"""Validate that the TikTok library is installed.
|
||||
|
||||
Args:
|
||||
values (Dict[str, Any]): The values to validate.
|
||||
values: The values to validate.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The validated values.
|
||||
The validated values.
|
||||
"""
|
||||
embeddings = values.get("embeddings")
|
||||
types_ = []
|
||||
@@ -159,7 +158,7 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
"""Return the output keys of the chain.
|
||||
|
||||
Returns:
|
||||
List[str]: The output keys.
|
||||
The output keys.
|
||||
"""
|
||||
return ["score"]
|
||||
|
||||
@@ -173,10 +172,10 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
"""Get the metric function for the given metric name.
|
||||
|
||||
Args:
|
||||
metric (EmbeddingDistance): The metric name.
|
||||
metric: The metric name.
|
||||
|
||||
Returns:
|
||||
Any: The metric function.
|
||||
The metric function.
|
||||
"""
|
||||
metrics = {
|
||||
EmbeddingDistance.COSINE: self._cosine_distance,
|
||||
@@ -334,7 +333,7 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
vectors (np.ndarray): The input vectors.
|
||||
|
||||
Returns:
|
||||
float: The computed score.
|
||||
The computed score.
|
||||
"""
|
||||
metric = self._get_metric(self.distance_metric)
|
||||
if _check_numpy() and isinstance(vectors, _import_numpy().ndarray):
|
||||
@@ -362,7 +361,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
"""Return whether the chain requires a reference.
|
||||
|
||||
Returns:
|
||||
bool: True if a reference is required, False otherwise.
|
||||
`True` if a reference is required, `False` otherwise.
|
||||
"""
|
||||
return True
|
||||
|
||||
@@ -376,7 +375,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
"""Return the input keys of the chain.
|
||||
|
||||
Returns:
|
||||
List[str]: The input keys.
|
||||
The input keys.
|
||||
"""
|
||||
return ["prediction", "reference"]
|
||||
|
||||
@@ -393,7 +392,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
run_manager: The callback manager.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The computed score.
|
||||
The computed score.
|
||||
"""
|
||||
vectors = self.embeddings.embed_documents(
|
||||
[inputs["prediction"], inputs["reference"]],
|
||||
@@ -413,12 +412,11 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
"""Asynchronously compute the score for a prediction and reference.
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, Any]): The input data.
|
||||
run_manager (AsyncCallbackManagerForChainRun, optional):
|
||||
The callback manager.
|
||||
inputs: The input data.
|
||||
run_manager: The callback manager.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The computed score.
|
||||
The computed score.
|
||||
"""
|
||||
vectors = await self.embeddings.aembed_documents(
|
||||
[
|
||||
@@ -523,7 +521,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
"""Return the input keys of the chain.
|
||||
|
||||
Returns:
|
||||
List[str]: The input keys.
|
||||
The input keys.
|
||||
"""
|
||||
return ["prediction", "prediction_b"]
|
||||
|
||||
@@ -541,12 +539,11 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
"""Compute the score for two predictions.
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, Any]): The input data.
|
||||
run_manager (CallbackManagerForChainRun, optional):
|
||||
The callback manager.
|
||||
inputs: The input data.
|
||||
run_manager: The callback manager.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The computed score.
|
||||
The computed score.
|
||||
"""
|
||||
vectors = self.embeddings.embed_documents(
|
||||
[
|
||||
@@ -569,12 +566,11 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
"""Asynchronously compute the score for two predictions.
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, Any]): The input data.
|
||||
run_manager (AsyncCallbackManagerForChainRun, optional):
|
||||
The callback manager.
|
||||
inputs: The input data.
|
||||
run_manager: The callback manager.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The computed score.
|
||||
The computed score.
|
||||
"""
|
||||
vectors = await self.embeddings.aembed_documents(
|
||||
[
|
||||
|
||||
@@ -61,7 +61,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
|
||||
"""Get the input keys.
|
||||
|
||||
Returns:
|
||||
List[str]: The input keys.
|
||||
The input keys.
|
||||
"""
|
||||
return ["reference", "prediction"]
|
||||
|
||||
@@ -70,7 +70,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
|
||||
"""Get the evaluation name.
|
||||
|
||||
Returns:
|
||||
str: The evaluation name.
|
||||
The evaluation name.
|
||||
"""
|
||||
return "exact_match"
|
||||
|
||||
|
||||
@@ -71,10 +71,11 @@ class JsonValidityEvaluator(StringEvaluator):
|
||||
**kwargs: Additional keyword arguments (not used).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the evaluation score. The score is 1 if
|
||||
the prediction is valid JSON, and 0 otherwise.
|
||||
If the prediction is not valid JSON, the dictionary also contains
|
||||
a "reasoning" field with the error message.
|
||||
A dictionary containing the evaluation score. The score is `1` if
|
||||
the prediction is valid JSON, and `0` otherwise.
|
||||
|
||||
If the prediction is not valid JSON, the dictionary also contains
|
||||
a `reasoning` field with the error message.
|
||||
|
||||
"""
|
||||
try:
|
||||
|
||||
@@ -50,7 +50,7 @@ class JsonEditDistanceEvaluator(StringEvaluator):
|
||||
|
||||
Raises:
|
||||
ImportError: If the `rapidfuzz` package is not installed and no
|
||||
`string_distance` function is provided.
|
||||
`string_distance` function is provided.
|
||||
"""
|
||||
super().__init__()
|
||||
if string_distance is not None:
|
||||
|
||||
@@ -113,17 +113,16 @@ class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
|
||||
"""Load QA Eval Chain from LLM.
|
||||
|
||||
Args:
|
||||
llm (BaseLanguageModel): the base language model to use.
|
||||
llm: The base language model to use.
|
||||
prompt: A prompt template containing the input_variables:
|
||||
`'input'`, `'answer'` and `'result'` that will be used as the prompt
|
||||
for evaluation.
|
||||
|
||||
prompt (PromptTemplate): A prompt template containing the input_variables:
|
||||
'input', 'answer' and 'result' that will be used as the prompt
|
||||
for evaluation.
|
||||
Defaults to PROMPT.
|
||||
|
||||
**kwargs: additional keyword arguments.
|
||||
Defaults to `PROMPT`.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
QAEvalChain: the loaded QA eval chain.
|
||||
The loaded QA eval chain.
|
||||
"""
|
||||
prompt = prompt or PROMPT
|
||||
expected_input_vars = {"query", "answer", "result"}
|
||||
@@ -264,17 +263,16 @@ class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
|
||||
"""Load QA Eval Chain from LLM.
|
||||
|
||||
Args:
|
||||
llm (BaseLanguageModel): the base language model to use.
|
||||
llm: The base language model to use.
|
||||
prompt: A prompt template containing the `input_variables`:
|
||||
`'query'`, `'context'` and `'result'` that will be used as the prompt
|
||||
for evaluation.
|
||||
|
||||
prompt (PromptTemplate): A prompt template containing the input_variables:
|
||||
'query', 'context' and 'result' that will be used as the prompt
|
||||
for evaluation.
|
||||
Defaults to PROMPT.
|
||||
|
||||
**kwargs: additional keyword arguments.
|
||||
Defaults to `PROMPT`.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
ContextQAEvalChain: the loaded QA eval chain.
|
||||
The loaded QA eval chain.
|
||||
"""
|
||||
prompt = prompt or CONTEXT_PROMPT
|
||||
cls._validate_input_vars(prompt)
|
||||
|
||||
@@ -53,7 +53,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
|
||||
"""Get the input keys.
|
||||
|
||||
Returns:
|
||||
List[str]: The input keys.
|
||||
The input keys.
|
||||
"""
|
||||
return ["reference", "prediction"]
|
||||
|
||||
@@ -62,7 +62,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
|
||||
"""Get the evaluation name.
|
||||
|
||||
Returns:
|
||||
str: The evaluation name.
|
||||
The evaluation name.
|
||||
"""
|
||||
return "regex_match"
|
||||
|
||||
|
||||
@@ -114,8 +114,8 @@ class _EvalArgsMixin:
|
||||
"""Check if the evaluation arguments are valid.
|
||||
|
||||
Args:
|
||||
reference (str | None, optional): The reference label.
|
||||
input_ (str | None, optional): The input string.
|
||||
reference: The reference label.
|
||||
input_: The input string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the evaluator requires an input string but none is provided,
|
||||
@@ -162,17 +162,17 @@ class StringEvaluator(_EvalArgsMixin, ABC):
|
||||
"""Evaluate Chain or LLM output, based on optional input and label.
|
||||
|
||||
Args:
|
||||
prediction (str): The LLM or chain prediction to evaluate.
|
||||
reference (str | None, optional): The reference label to evaluate against.
|
||||
input (str | None, optional): The input to consider during evaluation.
|
||||
kwargs: Additional keyword arguments, including callbacks, tags, etc.
|
||||
prediction: The LLM or chain prediction to evaluate.
|
||||
reference: The reference label to evaluate against.
|
||||
input: The input to consider during evaluation.
|
||||
**kwargs: Additional keyword arguments, including callbacks, tags, etc.
|
||||
|
||||
Returns:
|
||||
dict: The evaluation results containing the score or value.
|
||||
It is recommended that the dictionary contain the following keys:
|
||||
- score: the score of the evaluation, if applicable.
|
||||
- value: the string value of the evaluation, if applicable.
|
||||
- reasoning: the reasoning for the evaluation, if applicable.
|
||||
The evaluation results containing the score or value.
|
||||
It is recommended that the dictionary contain the following keys:
|
||||
- score: the score of the evaluation, if applicable.
|
||||
- value: the string value of the evaluation, if applicable.
|
||||
- reasoning: the reasoning for the evaluation, if applicable.
|
||||
"""
|
||||
|
||||
async def _aevaluate_strings(
|
||||
@@ -186,17 +186,17 @@ class StringEvaluator(_EvalArgsMixin, ABC):
|
||||
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.
|
||||
|
||||
Args:
|
||||
prediction (str): The LLM or chain prediction to evaluate.
|
||||
reference (str | None, optional): The reference label to evaluate against.
|
||||
input (str | None, optional): The input to consider during evaluation.
|
||||
kwargs: Additional keyword arguments, including callbacks, tags, etc.
|
||||
prediction: The LLM or chain prediction to evaluate.
|
||||
reference: The reference label to evaluate against.
|
||||
input: The input to consider during evaluation.
|
||||
**kwargs: Additional keyword arguments, including callbacks, tags, etc.
|
||||
|
||||
Returns:
|
||||
dict: The evaluation results containing the score or value.
|
||||
It is recommended that the dictionary contain the following keys:
|
||||
- score: the score of the evaluation, if applicable.
|
||||
- value: the string value of the evaluation, if applicable.
|
||||
- reasoning: the reasoning for the evaluation, if applicable.
|
||||
The evaluation results containing the score or value.
|
||||
It is recommended that the dictionary contain the following keys:
|
||||
- score: the score of the evaluation, if applicable.
|
||||
- value: the string value of the evaluation, if applicable.
|
||||
- reasoning: the reasoning for the evaluation, if applicable.
|
||||
""" # noqa: E501
|
||||
return await run_in_executor(
|
||||
None,
|
||||
@@ -218,13 +218,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
|
||||
"""Evaluate Chain or LLM output, based on optional input and label.
|
||||
|
||||
Args:
|
||||
prediction (str): The LLM or chain prediction to evaluate.
|
||||
reference (str | None, optional): The reference label to evaluate against.
|
||||
input (str | None, optional): The input to consider during evaluation.
|
||||
kwargs: Additional keyword arguments, including callbacks, tags, etc.
|
||||
prediction: The LLM or chain prediction to evaluate.
|
||||
reference: The reference label to evaluate against.
|
||||
input: The input to consider during evaluation.
|
||||
**kwargs: Additional keyword arguments, including callbacks, tags, etc.
|
||||
|
||||
Returns:
|
||||
dict: The evaluation results containing the score or value.
|
||||
The evaluation results containing the score or value.
|
||||
"""
|
||||
self._check_evaluation_args(reference=reference, input_=input)
|
||||
return self._evaluate_strings(
|
||||
@@ -245,13 +245,13 @@ class StringEvaluator(_EvalArgsMixin, ABC):
|
||||
"""Asynchronously evaluate Chain or LLM output, based on optional input and label.
|
||||
|
||||
Args:
|
||||
prediction (str): The LLM or chain prediction to evaluate.
|
||||
reference (str | None, optional): The reference label to evaluate against.
|
||||
input (str | None, optional): The input to consider during evaluation.
|
||||
kwargs: Additional keyword arguments, including callbacks, tags, etc.
|
||||
prediction: The LLM or chain prediction to evaluate.
|
||||
reference: The reference label to evaluate against.
|
||||
input: The input to consider during evaluation.
|
||||
**kwargs: Additional keyword arguments, including callbacks, tags, etc.
|
||||
|
||||
Returns:
|
||||
dict: The evaluation results containing the score or value.
|
||||
The evaluation results containing the score or value.
|
||||
""" # noqa: E501
|
||||
self._check_evaluation_args(reference=reference, input_=input)
|
||||
return await self._aevaluate_strings(
|
||||
@@ -278,14 +278,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
|
||||
"""Evaluate the output string pairs.
|
||||
|
||||
Args:
|
||||
prediction (str): The output string from the first model.
|
||||
prediction_b (str): The output string from the second model.
|
||||
reference (str | None, optional): The expected output / reference string.
|
||||
input (str | None, optional): The input string.
|
||||
kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
|
||||
prediction: The output string from the first model.
|
||||
prediction_b: The output string from the second model.
|
||||
reference: The expected output / reference string.
|
||||
input: The input string.
|
||||
**kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the preference, scores, and/or other information.
|
||||
A dictionary containing the preference, scores, and/or other information.
|
||||
""" # noqa: E501
|
||||
|
||||
async def _aevaluate_string_pairs(
|
||||
@@ -300,14 +300,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
|
||||
"""Asynchronously evaluate the output string pairs.
|
||||
|
||||
Args:
|
||||
prediction (str): The output string from the first model.
|
||||
prediction_b (str): The output string from the second model.
|
||||
reference (str | None, optional): The expected output / reference string.
|
||||
input (str | None, optional): The input string.
|
||||
kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
|
||||
prediction: The output string from the first model.
|
||||
prediction_b: The output string from the second model.
|
||||
reference: The expected output / reference string.
|
||||
input: The input string.
|
||||
**kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the preference, scores, and/or other information.
|
||||
A dictionary containing the preference, scores, and/or other information.
|
||||
""" # noqa: E501
|
||||
return await run_in_executor(
|
||||
None,
|
||||
@@ -331,14 +331,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
|
||||
"""Evaluate the output string pairs.
|
||||
|
||||
Args:
|
||||
prediction (str): The output string from the first model.
|
||||
prediction_b (str): The output string from the second model.
|
||||
reference (str | None, optional): The expected output / reference string.
|
||||
input (str | None, optional): The input string.
|
||||
kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
|
||||
prediction: The output string from the first model.
|
||||
prediction_b: The output string from the second model.
|
||||
reference: The expected output / reference string.
|
||||
input: The input string.
|
||||
**kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the preference, scores, and/or other information.
|
||||
A dictionary containing the preference, scores, and/or other information.
|
||||
""" # noqa: E501
|
||||
self._check_evaluation_args(reference=reference, input_=input)
|
||||
return self._evaluate_string_pairs(
|
||||
@@ -361,14 +361,14 @@ class PairwiseStringEvaluator(_EvalArgsMixin, ABC):
|
||||
"""Asynchronously evaluate the output string pairs.
|
||||
|
||||
Args:
|
||||
prediction (str): The output string from the first model.
|
||||
prediction_b (str): The output string from the second model.
|
||||
reference (str | None, optional): The expected output / reference string.
|
||||
input (str | None, optional): The input string.
|
||||
kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
|
||||
prediction: The output string from the first model.
|
||||
prediction_b: The output string from the second model.
|
||||
reference: The expected output / reference string.
|
||||
input: The input string.
|
||||
**kwargs: Additional keyword arguments, such as callbacks and optional reference strings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the preference, scores, and/or other information.
|
||||
A dictionary containing the preference, scores, and/or other information.
|
||||
""" # noqa: E501
|
||||
self._check_evaluation_args(reference=reference, input_=input)
|
||||
return await self._aevaluate_string_pairs(
|
||||
|
||||
@@ -7,8 +7,8 @@ criteria and or a reference answer.
|
||||
Example:
|
||||
>>> from langchain_community.chat_models import ChatOpenAI
|
||||
>>> from langchain_classic.evaluation.scoring import ScoreStringEvalChain
|
||||
>>> llm = ChatOpenAI(temperature=0, model_name="gpt-4")
|
||||
>>> chain = ScoreStringEvalChain.from_llm(llm=llm)
|
||||
>>> model = ChatOpenAI(temperature=0, model_name="gpt-4")
|
||||
>>> chain = ScoreStringEvalChain.from_llm(llm=model)
|
||||
>>> result = chain.evaluate_strings(
|
||||
... input="What is the chemical formula for water?",
|
||||
... prediction="H2O",
|
||||
|
||||
@@ -56,10 +56,10 @@ def resolve_criteria(
|
||||
"""Resolve the criteria for the pairwise evaluator.
|
||||
|
||||
Args:
|
||||
criteria (Union[CRITERIA_TYPE, str], optional): The criteria to use.
|
||||
criteria: The criteria to use.
|
||||
|
||||
Returns:
|
||||
dict: The resolved criteria.
|
||||
The resolved criteria.
|
||||
|
||||
"""
|
||||
if criteria is None:
|
||||

@@ -156,8 +156,8 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
Example:
>>> from langchain_community.chat_models import ChatOpenAI
>>> from langchain_classic.evaluation.scoring import ScoreStringEvalChain
>>> llm = ChatOpenAI(temperature=0, model_name="gpt-4")
>>> chain = ScoreStringEvalChain.from_llm(llm=llm)
>>> model = ChatOpenAI(temperature=0, model_name="gpt-4")
>>> chain = ScoreStringEvalChain.from_llm(llm=model)
>>> result = chain.evaluate_strings(
... input="What is the chemical formula for water?",
... prediction="H2O",

@@ -196,7 +196,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
"""Return whether the chain requires a reference.

Returns:
bool: True if the chain requires a reference, False otherwise.
`True` if the chain requires a reference, `False` otherwise.

"""
return False

@@ -206,7 +206,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
"""Return whether the chain requires an input.

Returns:
bool: True if the chain requires an input, False otherwise.
`True` if the chain requires an input, `False` otherwise.

"""
return True

@@ -227,7 +227,7 @@ class ScoreStringEvalChain(StringEvaluator, LLMEvalChain, LLMChain):
"""Return the warning to show when reference is ignored.

Returns:
str: The warning to show when reference is ignored.
The warning to show when reference is ignored.

"""
return (

@@ -424,7 +424,7 @@ class LabeledScoreStringEvalChain(ScoreStringEvalChain):
"""Return whether the chain requires a reference.

Returns:
bool: True if the chain requires a reference, False otherwise.
`True` if the chain requires a reference, `False` otherwise.

"""
return True

@@ -442,14 +442,14 @@ class LabeledScoreStringEvalChain(ScoreStringEvalChain):
"""Initialize the LabeledScoreStringEvalChain from an LLM.

Args:
llm (BaseLanguageModel): The LLM to use.
prompt (PromptTemplate, optional): The prompt to use.
criteria (Union[CRITERIA_TYPE, str], optional): The criteria to use.
normalize_by (float, optional): The value to normalize the score by.
**kwargs (Any): Additional keyword arguments.
llm: The LLM to use.
prompt: The prompt to use.
criteria: The criteria to use.
normalize_by: The value to normalize the score by.
**kwargs: Additional keyword arguments.

Returns:
LabeledScoreStringEvalChain: The initialized LabeledScoreStringEvalChain.
The initialized LabeledScoreStringEvalChain.

Raises:
ValueError: If the input variables are not as expected.
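A hedged usage sketch for the labeled chain whose docstrings are cleaned above; the model choice and criteria value are illustrative, and the import path mirrors the module's own example:

>>> from langchain_community.chat_models import ChatOpenAI
>>> from langchain_classic.evaluation.scoring import LabeledScoreStringEvalChain
>>> model = ChatOpenAI(temperature=0, model_name="gpt-4")
>>> chain = LabeledScoreStringEvalChain.from_llm(llm=model, criteria="correctness")
>>> chain.evaluate_strings(
...     input="What is the chemical formula for water?",
...     prediction="H2O",
...     reference="H2O",
... )  # returns a dict with the score; this variant requires a reference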

@@ -22,10 +22,10 @@ def _load_rapidfuzz() -> Any:
"""Load the RapidFuzz library.

Raises:
ImportError: If the rapidfuzz library is not installed.
`ImportError`: If the rapidfuzz library is not installed.

Returns:
Any: The rapidfuzz.distance module.
The `rapidfuzz.distance` module.
"""
try:
import rapidfuzz

@@ -42,12 +42,12 @@ class StringDistance(str, Enum):
"""Distance metric to use.

Attributes:
DAMERAU_LEVENSHTEIN: The Damerau-Levenshtein distance.
LEVENSHTEIN: The Levenshtein distance.
JARO: The Jaro distance.
JARO_WINKLER: The Jaro-Winkler distance.
HAMMING: The Hamming distance.
INDEL: The Indel distance.
`DAMERAU_LEVENSHTEIN`: The Damerau-Levenshtein distance.
`LEVENSHTEIN`: The Levenshtein distance.
`JARO`: The Jaro distance.
`JARO_WINKLER`: The Jaro-Winkler distance.
`HAMMING`: The Hamming distance.
`INDEL`: The Indel distance.
"""

DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
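To ground the enum above, an assumed-usage sketch of a string-distance evaluator built on it; the `load_evaluator` helper belongs to the surrounding evaluation API and is not part of this hunk:

>>> from langchain_classic.evaluation import StringDistance, load_evaluator  # assumed path
>>> evaluator = load_evaluator("string_distance", distance=StringDistance.LEVENSHTEIN)
>>> evaluator.evaluate_strings(prediction="LangChain", reference="langchain")
{'score': ...}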

@@ -63,7 +63,7 @@ class _RapidFuzzChainMixin(Chain):

distance: StringDistance = Field(default=StringDistance.JARO_WINKLER)
normalize_score: bool = Field(default=True)
"""Whether to normalize the score to a value between 0 and 1.
"""Whether to normalize the score to a value between `0` and `1`.
Applies only to the Levenshtein and Damerau-Levenshtein distances."""

@pre_init

@@ -71,10 +71,10 @@ class _RapidFuzzChainMixin(Chain):
"""Validate that the rapidfuzz library is installed.

Args:
values (Dict[str, Any]): The input values.
values: The input values.

Returns:
Dict[str, Any]: The validated values.
The validated values.
"""
_load_rapidfuzz()
return values

@@ -84,7 +84,7 @@ class _RapidFuzzChainMixin(Chain):
"""Get the output keys.

Returns:
List[str]: The output keys.
The output keys.
"""
return ["score"]

@@ -92,10 +92,10 @@ class _RapidFuzzChainMixin(Chain):
"""Prepare the output dictionary.

Args:
result (Dict[str, Any]): The evaluation results.
result: The evaluation results.

Returns:
Dict[str, Any]: The prepared output dictionary.
The prepared output dictionary.
"""
result = {"score": result["score"]}
if RUN_KEY in result:

@@ -111,7 +111,7 @@ class _RapidFuzzChainMixin(Chain):
normalize_score: Whether to normalize the score.

Returns:
Callable: The distance metric function.
The distance metric function.

Raises:
ValueError: If the distance metric is invalid.

@@ -142,7 +142,7 @@ class _RapidFuzzChainMixin(Chain):
"""Get the distance metric function.

Returns:
Callable: The distance metric function.
The distance metric function.
"""
return _RapidFuzzChainMixin._get_metric(
self.distance,

@@ -199,7 +199,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""Get the input keys.

Returns:
List[str]: The input keys.
The input keys.
"""
return ["reference", "prediction"]

@@ -208,7 +208,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""Get the evaluation name.

Returns:
str: The evaluation name.
The evaluation name.
"""
return f"{self.distance.value}_distance"

@@ -330,7 +330,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""Get the input keys.

Returns:
List[str]: The input keys.
The input keys.
"""
return ["prediction", "prediction_b"]

@@ -339,7 +339,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""Get the evaluation name.

Returns:
str: The evaluation name.
The evaluation name.
"""
return f"pairwise_{self.distance.value}_distance"

@@ -352,12 +352,11 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""Compute the string distance between two predictions.

Args:
inputs (Dict[str, Any]): The input values.
run_manager (CallbackManagerForChainRun , optional):
The callback manager.
inputs: The input values.
run_manager: The callback manager.

Returns:
Dict[str, Any]: The evaluation results containing the score.
The evaluation results containing the score.
"""
return {
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),

@@ -372,12 +371,11 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""Asynchronously compute the string distance between two predictions.

Args:
inputs (Dict[str, Any]): The input values.
run_manager (AsyncCallbackManagerForChainRun , optional):
The callback manager.
inputs: The input values.
run_manager: The callback manager.

Returns:
Dict[str, Any]: The evaluation results containing the score.
The evaluation results containing the score.
"""
return {
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
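And the pairwise variant from the hunks above, again as an assumed-usage sketch reusing the `load_evaluator` helper from the earlier example (the evaluator name is an assumption):

>>> pairwise = load_evaluator("pairwise_string_distance")
>>> pairwise.evaluate_string_pairs(prediction="H2O", prediction_b="H20")
{'score': ...}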

@@ -55,7 +55,7 @@ class VectorStoreIndexWrapper(BaseModel):
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"llm = OpenAI(temperature=0)"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}

@@ -90,7 +90,7 @@ class VectorStoreIndexWrapper(BaseModel):
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"llm = OpenAI(temperature=0)"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}

@@ -125,7 +125,7 @@ class VectorStoreIndexWrapper(BaseModel):
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"llm = OpenAI(temperature=0)"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}

@@ -160,7 +160,7 @@ class VectorStoreIndexWrapper(BaseModel):
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"llm = OpenAI(temperature=0)"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}
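The error message above already contains the fix; spelled out as a hedged sketch, where `index` stands for a previously built VectorStoreIndexWrapper and the question is a placeholder:

>>> from langchain_openai import OpenAI
>>> model = OpenAI(temperature=0)
>>> index.query("What does the document say about pricing?", llm=model)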

@@ -21,15 +21,15 @@ class ModelLaboratory:
Args:
chains: A sequence of chains to experiment with.
Each chain must have exactly one input and one output variable.
names (list[str] | None): Optional list of names corresponding to each
chain. If provided, its length must match the number of chains.
names: Optional list of names corresponding to each chain.
If provided, its length must match the number of chains.

Raises:
ValueError: If any chain is not an instance of `Chain`.
ValueError: If a chain does not have exactly one input variable.
ValueError: If a chain does not have exactly one output variable.
ValueError: If the length of `names` does not match the number of chains.
`ValueError`: If any chain is not an instance of `Chain`.
`ValueError`: If a chain does not have exactly one input variable.
`ValueError`: If a chain does not have exactly one output variable.
`ValueError`: If the length of `names` does not match the number of chains.
"""
for chain in chains:
if not isinstance(chain, Chain):

@@ -72,7 +72,7 @@ class ModelLaboratory:
If provided, the prompt must contain exactly one input variable.

Returns:
ModelLaboratory: An instance of `ModelLaboratory` initialized with LLMs.
An instance of `ModelLaboratory` initialized with LLMs.
"""
if prompt is None:
prompt = PromptTemplate(input_variables=["_input"], template="{_input}")
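A hedged sketch of the class documented above; the import path and the two placeholder models are assumptions:

>>> from langchain_classic.model_laboratory import ModelLaboratory  # assumed path
>>> lab = ModelLaboratory.from_llms([model_a, model_b])  # model_a / model_b: any two LLMs
>>> lab.compare("What is the chemical formula for water?")  # prints each model's output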

@@ -96,7 +96,7 @@ class StructuredOutputParser(BaseOutputParser[dict[str, Any]]):
# ```

Args:
only_json (bool): If `True`, only the json in the Markdown code snippet
only_json: If `True`, only the json in the Markdown code snippet
will be returned, without the introducing text. Defaults to `False`.
"""
schema_str = "\n".join(

@@ -16,10 +16,10 @@ T = TypeVar("T", bound=BaseModel)


class YamlOutputParser(BaseOutputParser[T]):
"""Parse YAML output using a pydantic model."""
"""Parse YAML output using a Pydantic model."""

pydantic_object: type[T]
"""The pydantic model to parse."""
"""The Pydantic model to parse."""
pattern: re.Pattern = re.compile(
r"^```(?:ya?ml)?(?P<yaml>[^`]*)",
re.MULTILINE | re.DOTALL,
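A hedged sketch of the YAML parser touched above with a small Pydantic model; the import path is an assumption based on the package rename in this changeset:

>>> from pydantic import BaseModel
>>> from langchain_classic.output_parsers import YamlOutputParser  # assumed path
>>> class Joke(BaseModel):
...     setup: str
...     punchline: str
>>> parser = YamlOutputParser(pydantic_object=Joke)
>>> parser.parse("```yaml\nsetup: Why?\npunchline: Because.\n```")  # -> Joke(setup='Why?', punchline='Because.')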

@@ -299,8 +299,8 @@ class EnsembleRetriever(BaseRetriever):
doc_lists: A list of rank lists, where each rank list contains unique items.

Returns:
list: The final aggregated list of items sorted by their weighted RRF
scores in descending order.
The final aggregated list of items sorted by their weighted RRF
scores in descending order.
"""
if len(doc_lists) != len(self.weights):
msg = "Number of rank lists must be equal to the number of weights."
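For readers unfamiliar with weighted Reciprocal Rank Fusion, a self-contained sketch of the aggregation the docstring describes; the constant `c = 60` mirrors the conventional RRF smoothing term, and the function is an illustration rather than the retriever's exact code:

from collections import defaultdict

def weighted_rrf(doc_lists: list[list[str]], weights: list[float], c: int = 60) -> list[str]:
    # Each item accumulates weight / (rank + c) across every rank list it appears in
    scores: dict[str, float] = defaultdict(float)
    for docs, weight in zip(doc_lists, weights):
        for rank, doc in enumerate(docs, start=1):
            scores[doc] += weight / (rank + c)
    # Final aggregated list, sorted by weighted RRF score in descending order
    return sorted(scores, key=scores.__getitem__, reverse=True)

weighted_rrf([["a", "b", "c"], ["b", "a"]], weights=[0.7, 0.3])  # -> ['a', 'b', 'c']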

@@ -22,8 +22,8 @@ from langchain_classic.smith import RunEvalConfig, run_on_dataset
# Chains may have memory. Passing in a constructor function lets the
# evaluation framework avoid cross-contamination between runs.
def construct_chain():
llm = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(llm, "What's the answer to {your_input_key}")
model = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(model, "What's the answer to {your_input_key}")
return chain


@@ -16,8 +16,8 @@ from langchain_classic.smith import EvaluatorType, RunEvalConfig, run_on_dataset


def construct_chain():
llm = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(llm, "What's the answer to {your_input_key}")
model = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(model, "What's the answer to {your_input_key}")
return chain


@@ -982,8 +982,7 @@ def _run_llm_or_chain(
input_mapper: Optional function to map the input to the expected format.

Returns:
Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
The outputs of the model or chain.
The outputs of the model or chain.
"""
chain_or_llm = (
"LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"

@@ -1396,9 +1395,9 @@ async def arun_on_dataset(
# Chains may have memory. Passing in a constructor function lets the
# evaluation framework avoid cross-contamination between runs.
def construct_chain():
llm = ChatOpenAI(temperature=0)
model = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(
llm,
model,
"What's the answer to {your_input_key}"
)
return chain

@@ -1571,9 +1570,9 @@ def run_on_dataset(
# Chains may have memory. Passing in a constructor function lets the
# evaluation framework avoid cross-contamination between runs.
def construct_chain():
llm = ChatOpenAI(temperature=0)
model = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(
llm,
model,
"What's the answer to {your_input_key}"
)
return chain

@@ -47,14 +47,15 @@ class LocalFileStore(ByteStore):
"""Implement the BaseStore interface for the local file system.

Args:
root_path (Union[str, Path]): The root path of the file store. All keys are
interpreted as paths relative to this root.
chmod_file: (optional, defaults to `None`) If specified, sets permissions
for newly created files, overriding the current `umask` if needed.
chmod_dir: (optional, defaults to `None`) If specified, sets permissions
for newly created dirs, overriding the current `umask` if needed.
update_atime: (optional, defaults to `False`) If `True`, updates the
filesystem access time (but not the modified time) when a file is read.
root_path: The root path of the file store. All keys are interpreted as
paths relative to this root.
chmod_file: If specified, sets permissions for newly created files,
overriding the current `umask` if needed.
chmod_dir: If specified, sets permissions for newly created dirs,
overriding the current `umask` if needed.
update_atime: If `True`, updates the filesystem access time
(but not the modified time) when a file is read.

This allows MRU/LRU cache policies to be implemented for filesystems
where access time updates are disabled.
"""

@@ -67,10 +68,10 @@ class LocalFileStore(ByteStore):
"""Get the full path for a given key relative to the root path.

Args:
key (str): The key relative to the root path.
key: The key relative to the root path.

Returns:
Path: The full path for the given key.
The full path for the given key.
"""
if not re.match(r"^[a-zA-Z0-9_.\-/]+$", key):
msg = f"Invalid characters in key: {key}"

@@ -148,10 +149,8 @@ class LocalFileStore(ByteStore):
"""Delete the given keys and their associated values.

Args:
keys (Sequence[str]): A sequence of keys to delete.
keys: A sequence of keys to delete.

Returns:
None
"""
for key in keys:
full_path = self._get_full_path(key)

@@ -162,10 +161,10 @@ class LocalFileStore(ByteStore):
"""Get an iterator over keys that match the given prefix.

Args:
prefix (str | None): The prefix to match.
prefix: The prefix to match.

Returns:
Iterator[str]: An iterator over keys that match the given prefix.
Yields:
Keys that match the given prefix.
"""
prefix_path = self._get_full_path(prefix) if prefix else self.root_path
for file in prefix_path.rglob("*"):
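A hedged usage sketch of the store whose docstrings are cleaned above; the root path and permissions are illustrative, and the import path is assumed:

>>> from langchain_classic.storage import LocalFileStore  # assumed path
>>> store = LocalFileStore("/tmp/byte_store", chmod_file=0o640)
>>> store.mset([("greeting", b"hello"), ("farewell", b"bye")])
>>> store.mget(["greeting", "missing"])
[b'hello', None]
>>> sorted(store.yield_keys())
['farewell', 'greeting']
>>> store.mdelete(["greeting"])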

@@ -27,10 +27,10 @@ def __getattr__(name: str) -> Any:
"""Dynamically retrieve attributes from the updated module path.

Args:
name (str): The name of the attribute to import.
name: The name of the attribute to import.

Returns:
Any: The resolved attribute from the updated path.
The resolved attribute from the updated path.
"""
return _import_attribute(name)


@@ -35,10 +35,10 @@ def __getattr__(name: str) -> Any:
at runtime and forward them to their new locations.

Args:
name (str): The name of the attribute to import.
name: The name of the attribute to import.

Returns:
Any: The resolved attribute from the appropriate updated module.
The resolved attribute from the appropriate updated module.
"""
return _import_attribute(name)


@@ -33,10 +33,10 @@ def __getattr__(name: str) -> Any:
at runtime and forward them to their new locations.

Args:
name (str): The name of the attribute to import.
name: The name of the attribute to import.

Returns:
Any: The resolved attribute from the appropriate updated module.
The resolved attribute from the appropriate updated module.
"""
return _import_attribute(name)
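These proxy modules all rely on module-level `__getattr__` (PEP 562). As background, a generic self-contained sketch of the pattern; the lookup table and target module are hypothetical and this is not the repo's actual `_import_attribute`:

from importlib import import_module
from typing import Any

# Hypothetical mapping from old attribute names to the modules they moved to
DEPRECATED_LOOKUP = {"FooLoader": "some_new_package.loaders"}

def __getattr__(name: str) -> Any:
    # Resolve attributes that have moved, forwarding them to their new location
    if name in DEPRECATED_LOOKUP:
        module = import_module(DEPRECATED_LOOKUP[name])
        return getattr(module, name)
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)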

@@ -108,15 +108,15 @@ def test_no_more_changes_to_proxy_community() -> None:


def extract_deprecated_lookup(file_path: str) -> dict[str, Any] | None:
"""Detect and extracts the value of a dictionary named DEPRECATED_LOOKUP.
"""Detect and extracts the value of a dictionary named `DEPRECATED_LOOKUP`.

This variable is located in the global namespace of a Python file.

Args:
file_path (str): The path to the Python file.
file_path: The path to the Python file.

Returns:
dict or None: The value of DEPRECATED_LOOKUP if it exists, None otherwise.
The value of `DEPRECATED_LOOKUP` if it exists, `None` otherwise.
"""
tree = ast.parse(Path(file_path).read_text(encoding="utf-8"), filename=file_path)

@@ -136,10 +136,10 @@ def _dict_from_ast(node: ast.Dict) -> dict[str, str]:
"""Convert an AST dict node to a Python dictionary, assuming str to str format.

Args:
node (ast.Dict): The AST node representing a dictionary.
node: The AST node representing a dictionary.

Returns:
dict: The corresponding Python dictionary.
The corresponding Python dictionary.
"""
result: dict[str, str] = {}
for key, value in zip(node.keys, node.values, strict=False):

@@ -153,10 +153,10 @@ def _literal_eval_str(node: ast.AST) -> str:
"""Evaluate an AST literal node to its corresponding string value.

Args:
node (ast.AST): The AST node representing a literal value.
node: The AST node representing a literal value.

Returns:
str: The corresponding string value.
The corresponding string value.
"""
if isinstance(node, ast.Constant) and isinstance(node.value, str):
return node.value

@@ -13,9 +13,6 @@ from typing import (
get_type_hints,
)

if TYPE_CHECKING:
from collections.abc import Awaitable

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool

@@ -47,11 +44,10 @@ from langchain.agents.structured_output import (
ToolStrategy,
)
from langchain.chat_models import init_chat_model
from langchain.tools import ToolNode
from langchain.tools.tool_node import ToolCallWithContext
from langchain.tools.tool_node import ToolCallWithContext, _ToolNode

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Awaitable, Callable, Sequence

from langchain_core.runnables import Runnable
from langgraph.cache.base import BaseCache

@@ -449,6 +445,70 @@ def _chain_tool_call_wrappers(
return result


def _chain_async_tool_call_wrappers(
wrappers: Sequence[
Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
]
],
) -> (
Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
]
| None
):
"""Compose async wrappers into middleware stack (first = outermost).

Args:
wrappers: Async wrappers in middleware order.

Returns:
Composed async wrapper, or None if empty.
"""
if not wrappers:
return None

if len(wrappers) == 1:
return wrappers[0]

def compose_two(
outer: Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
],
inner: Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
],
) -> Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
]:
"""Compose two async wrappers where outer wraps inner."""

async def composed(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
# Create an async callable that invokes inner with the original execute
async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
return await inner(req, execute)

# Outer can call call_inner multiple times
return await outer(request, call_inner)

return composed

# Chain all wrappers: first -> second -> ... -> last
result = wrappers[-1]
for wrapper in reversed(wrappers[:-1]):
result = compose_two(wrapper, result)

return result
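To make the composition order concrete, a self-contained sketch with plain async functions standing in for middleware wrappers; the names are stand-ins, not the real `ToolCallRequest`/`ToolMessage` types:

import asyncio

async def outer(request, call_next):
    # First wrapper in the list ends up outermost
    print("outer: before")
    result = await call_next(request)
    print("outer: after")
    return result

async def inner(request, call_next):
    print("inner: before")
    result = await call_next(request)
    print("inner: after")
    return result

async def execute(request):
    # Stand-in for the actual tool execution
    return f"executed {request}"

def compose_two(outer_w, inner_w):
    async def composed(request, call_next):
        async def call_inner(req):
            return await inner_w(req, call_next)
        return await outer_w(request, call_inner)
    return composed

print(asyncio.run(compose_two(outer, inner)("tool_call", execute)))
# outer: before / inner: before / inner: after / outer: after / executed tool_call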

def create_agent( # noqa: PLR0915
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,

@@ -576,9 +636,14 @@ def create_agent( # noqa: PLR0915
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]

# Collect middleware with wrap_tool_call hooks
# Collect middleware with wrap_tool_call or awrap_tool_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_wrap_tool_call = [
m for m in middleware if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
m
for m in middleware
if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
]

# Chain all wrap_tool_call handlers into a single composed handler

@@ -587,8 +652,24 @@ def create_agent( # noqa: PLR0915
wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)

# Collect middleware with awrap_tool_call or wrap_tool_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_awrap_tool_call = [
m
for m in middleware
if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
]

# Chain all awrap_tool_call handlers into a single composed async handler
awrap_tool_call_wrapper = None
if middleware_w_awrap_tool_call:
async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)

# Setup tools
tool_node: ToolNode | None = None
tool_node: _ToolNode | None = None
# Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
built_in_tools = [t for t in tools if isinstance(t, dict)]
regular_tools = [t for t in tools if not isinstance(t, dict)]

@@ -598,7 +679,11 @@ def create_agent( # noqa: PLR0915

# Only create ToolNode if we have client-side tools
tool_node = (
ToolNode(tools=available_tools, wrap_tool_call=wrap_tool_call_wrapper)
_ToolNode(
tools=available_tools,
wrap_tool_call=wrap_tool_call_wrapper,
awrap_tool_call=awrap_tool_call_wrapper,
)
if available_tools
else None
)
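A hedged sketch of middleware exercising both execution paths wired up here; the import paths, hook signatures, and tool are assumptions inferred from the wrapper types in this diff, not a definitive API reference:

from langchain.agents import create_agent  # assumed path
from langchain.agents.middleware import AgentMiddleware  # assumed path

class LoggedToolCalls(AgentMiddleware):
    def wrap_tool_call(self, request, handler):
        # Sync path: composed by _chain_tool_call_wrappers
        print("sync tool call")
        return handler(request)

    async def awrap_tool_call(self, request, handler):
        # Async path: composed by _chain_async_tool_call_wrappers
        print("async tool call")
        return await handler(request)

agent = create_agent(
    model="openai:gpt-4o-mini",  # placeholder model identifier
    tools=[my_tool],  # my_tool: any BaseTool or callable (hypothetical)
    middleware=[LoggedToolCalls()],
)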

@@ -640,13 +725,23 @@ def create_agent( # noqa: PLR0915
if m.__class__.after_agent is not AgentMiddleware.after_agent
or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
]
# Collect middleware with wrap_model_call or awrap_model_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_wrap_model_call = [
m for m in middleware if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
m
for m in middleware
if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
]
# Collect middleware with awrap_model_call or wrap_model_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_awrap_model_call = [
m
for m in middleware
if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
]

# Compose wrap_model_call handlers into a single middleware stack (sync)

@@ -1378,7 +1473,7 @@ def _make_model_to_model_edge(

def _make_tools_to_model_edge(
*,
tool_node: ToolNode,
tool_node: _ToolNode,
model_destination: str,
structured_output_tools: dict[str, OutputToolBinding],
end_destination: str,

@@ -20,7 +20,11 @@ from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
AgentState,
ModelCallHandler,
ModelCallResult,
ModelCallWrapper,
ModelRequest,
ModelResponse,
after_agent,
after_model,
before_agent,

@@ -28,6 +32,7 @@ from .types import (
dynamic_prompt,
hook_config,
wrap_model_call,
wrap_tool_call,
)

__all__ = [

@@ -41,9 +46,13 @@ __all__ = [
"InterruptOnConfig",
"LLMToolEmulator",
"LLMToolSelectorMiddleware",
"ModelCallHandler",
"ModelCallLimitMiddleware",
"ModelCallResult",
"ModelCallWrapper",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",
"PIIDetectionError",
"PIIMiddleware",
"PlanningMiddleware",

@@ -56,4 +65,5 @@ __all__ = [
"dynamic_prompt",
"hook_config",
"wrap_model_call",
"wrap_tool_call",
]
Some files were not shown because too many files have changed in this diff.