Merge branch 'master' into pgvectorstore-docs

This commit is contained in:
dishaprakash 2025-04-21 17:25:46 +00:00 committed by GitHub
commit ac7e73a531
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 235 additions and 44 deletions

View File

@ -11,6 +11,7 @@
import json import json
import os import os
import sys import sys
from datetime import datetime
from pathlib import Path from pathlib import Path
import toml import toml
@ -104,7 +105,7 @@ def skip_private_members(app, what, name, obj, skip, options):
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "🦜🔗 LangChain" project = "🦜🔗 LangChain"
copyright = "2023, LangChain Inc" copyright = f"{datetime.now().year}, LangChain Inc"
author = "LangChain, Inc" author = "LangChain, Inc"
html_favicon = "_static/img/brand/favicon.png" html_favicon = "_static/img/brand/favicon.png"

View File

@ -36,10 +36,7 @@
"pip install oracledb" "pip install oracledb"
], ],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false
"pycharm": {
"is_executing": true
}
} }
}, },
{ {
@ -51,10 +48,7 @@
"from settings import s" "from settings import s"
], ],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false
"pycharm": {
"is_executing": true
}
} }
}, },
{ {
@ -97,16 +91,14 @@
"doc_2 = doc_loader_2.load()" "doc_2 = doc_loader_2.load()"
], ],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false
"pycharm": {
"is_executing": true
}
} }
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
"With TLS authentication, wallet_location and wallet_password are not required." "With TLS authentication, wallet_location and wallet_password are not required.\n",
"Bind variable option is provided by argument \"parameters\"."
], ],
"metadata": { "metadata": {
"collapsed": false "collapsed": false
@ -117,6 +109,8 @@
"execution_count": null, "execution_count": null,
"outputs": [], "outputs": [],
"source": [ "source": [
"SQL_QUERY = \"select channel_id, channel_desc from sh.channels where channel_desc = :1 fetch first 5 rows only\"\n",
"\n",
"doc_loader_3 = OracleAutonomousDatabaseLoader(\n", "doc_loader_3 = OracleAutonomousDatabaseLoader(\n",
" query=SQL_QUERY,\n", " query=SQL_QUERY,\n",
" user=s.USERNAME,\n", " user=s.USERNAME,\n",
@ -124,6 +118,7 @@
" schema=s.SCHEMA,\n", " schema=s.SCHEMA,\n",
" config_dir=s.CONFIG_DIR,\n", " config_dir=s.CONFIG_DIR,\n",
" tns_name=s.TNS_NAME,\n", " tns_name=s.TNS_NAME,\n",
" parameters=[\"Direct Sales\"],\n",
")\n", ")\n",
"doc_3 = doc_loader_3.load()\n", "doc_3 = doc_loader_3.load()\n",
"\n", "\n",
@ -133,6 +128,7 @@
" password=s.PASSWORD,\n", " password=s.PASSWORD,\n",
" schema=s.SCHEMA,\n", " schema=s.SCHEMA,\n",
" connection_string=s.CONNECTION_STRING,\n", " connection_string=s.CONNECTION_STRING,\n",
" parameters=[\"Direct Sales\"],\n",
")\n", ")\n",
"doc_4 = doc_loader_4.load()" "doc_4 = doc_loader_4.load()"
], ],

View File

@ -183,7 +183,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Initalize simple_datasource_qa for querying Tableau Datasources through VDS\n", "# Initialize simple_datasource_qa for querying Tableau Datasources through VDS\n",
"analyze_datasource = initialize_simple_datasource_qa(\n", "analyze_datasource = initialize_simple_datasource_qa(\n",
" domain=tableau_server,\n", " domain=tableau_server,\n",
" site=tableau_site,\n", " site=tableau_site,\n",

View File

@ -10,6 +10,38 @@ from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.outputs import ChatGeneration, LLMResult
MODEL_COST_PER_1K_TOKENS = { MODEL_COST_PER_1K_TOKENS = {
# GPT-4.1 input
"gpt-4.1": 0.002,
"gpt-4.1-2025-04-14": 0.002,
"gpt-4.1-cached": 0.0005,
"gpt-4.1-2025-04-14-cached": 0.0005,
# GPT-4.1 output
"gpt-4.1-completion": 0.008,
"gpt-4.1-2025-04-14-completion": 0.008,
# GPT-4.1-mini input
"gpt-4.1-mini": 0.0004,
"gpt-4.1-mini-2025-04-14": 0.0004,
"gpt-4.1-mini-cached": 0.0001,
"gpt-4.1-mini-2025-04-14-cached": 0.0001,
# GPT-4.1-mini output
"gpt-4.1-mini-completion": 0.0016,
"gpt-4.1-mini-2025-04-14-completion": 0.0016,
# GPT-4.1-nano input
"gpt-4.1-nano": 0.0001,
"gpt-4.1-nano-2025-04-14": 0.0001,
"gpt-4.1-nano-cached": 0.000025,
"gpt-4.1-nano-2025-04-14-cached": 0.000025,
# GPT-4.1-nano output
"gpt-4.1-nano-completion": 0.0004,
"gpt-4.1-nano-2025-04-14-completion": 0.0004,
# GPT-4.5-preview input
"gpt-4.5-preview": 0.075,
"gpt-4.5-preview-2025-02-27": 0.075,
"gpt-4.5-preview-cached": 0.0375,
"gpt-4.5-preview-2025-02-27-cached": 0.0375,
# GPT-4.5-preview output
"gpt-4.5-preview-completion": 0.15,
"gpt-4.5-preview-2025-02-27-completion": 0.15,
# OpenAI o1 input # OpenAI o1 input
"o1": 0.015, "o1": 0.015,
"o1-2024-12-17": 0.015, "o1-2024-12-17": 0.015,
@ -18,6 +50,28 @@ MODEL_COST_PER_1K_TOKENS = {
# OpenAI o1 output # OpenAI o1 output
"o1-completion": 0.06, "o1-completion": 0.06,
"o1-2024-12-17-completion": 0.06, "o1-2024-12-17-completion": 0.06,
# OpenAI o1-pro input
"o1-pro": 0.15,
"o1-pro-2025-03-19": 0.15,
# OpenAI o1-pro output
"o1-pro-completion": 0.6,
"o1-pro-2025-03-19-completion": 0.6,
# OpenAI o3 input
"o3": 0.01,
"o3-2025-04-16": 0.01,
"o3-cached": 0.0025,
"o3-2025-04-16-cached": 0.0025,
# OpenAI o3 output
"o3-completion": 0.04,
"o3-2025-04-16-completion": 0.04,
# OpenAI o4-mini input
"o4-mini": 0.0011,
"o4-mini-2025-04-16": 0.0011,
"o4-mini-cached": 0.000275,
"o4-mini-2025-04-16-cached": 0.000275,
# OpenAI o4-mini output
"o4-mini-completion": 0.0044,
"o4-mini-2025-04-16-completion": 0.0044,
# OpenAI o3-mini input # OpenAI o3-mini input
"o3-mini": 0.0011, "o3-mini": 0.0011,
"o3-mini-2025-01-31": 0.0011, "o3-mini-2025-01-31": 0.0011,
@ -26,6 +80,14 @@ MODEL_COST_PER_1K_TOKENS = {
# OpenAI o3-mini output # OpenAI o3-mini output
"o3-mini-completion": 0.0044, "o3-mini-completion": 0.0044,
"o3-mini-2025-01-31-completion": 0.0044, "o3-mini-2025-01-31-completion": 0.0044,
# OpenAI o1-mini input (updated pricing)
"o1-mini": 0.0011,
"o1-mini-cached": 0.00055,
"o1-mini-2024-09-12": 0.0011,
"o1-mini-2024-09-12-cached": 0.00055,
# OpenAI o1-mini output (updated pricing)
"o1-mini-completion": 0.0044,
"o1-mini-2024-09-12-completion": 0.0044,
# OpenAI o1-preview input # OpenAI o1-preview input
"o1-preview": 0.015, "o1-preview": 0.015,
"o1-preview-cached": 0.0075, "o1-preview-cached": 0.0075,
@ -34,22 +96,6 @@ MODEL_COST_PER_1K_TOKENS = {
# OpenAI o1-preview output # OpenAI o1-preview output
"o1-preview-completion": 0.06, "o1-preview-completion": 0.06,
"o1-preview-2024-09-12-completion": 0.06, "o1-preview-2024-09-12-completion": 0.06,
# OpenAI o1-mini input
"o1-mini": 0.003,
"o1-mini-cached": 0.0015,
"o1-mini-2024-09-12": 0.003,
"o1-mini-2024-09-12-cached": 0.0015,
# OpenAI o1-mini output
"o1-mini-completion": 0.012,
"o1-mini-2024-09-12-completion": 0.012,
# GPT-4o-mini input
"gpt-4o-mini": 0.00015,
"gpt-4o-mini-cached": 0.000075,
"gpt-4o-mini-2024-07-18": 0.00015,
"gpt-4o-mini-2024-07-18-cached": 0.000075,
# GPT-4o-mini output
"gpt-4o-mini-completion": 0.0006,
"gpt-4o-mini-2024-07-18-completion": 0.0006,
# GPT-4o input # GPT-4o input
"gpt-4o": 0.0025, "gpt-4o": 0.0025,
"gpt-4o-cached": 0.00125, "gpt-4o-cached": 0.00125,
@ -63,6 +109,65 @@ MODEL_COST_PER_1K_TOKENS = {
"gpt-4o-2024-05-13-completion": 0.015, "gpt-4o-2024-05-13-completion": 0.015,
"gpt-4o-2024-08-06-completion": 0.01, "gpt-4o-2024-08-06-completion": 0.01,
"gpt-4o-2024-11-20-completion": 0.01, "gpt-4o-2024-11-20-completion": 0.01,
# GPT-4o-audio-preview input
"gpt-4o-audio-preview": 0.0025,
"gpt-4o-audio-preview-2024-12-17": 0.0025,
"gpt-4o-audio-preview-2024-10-01": 0.0025,
# GPT-4o-audio-preview output
"gpt-4o-audio-preview-completion": 0.01,
"gpt-4o-audio-preview-2024-12-17-completion": 0.01,
"gpt-4o-audio-preview-2024-10-01-completion": 0.01,
# GPT-4o-realtime-preview input
"gpt-4o-realtime-preview": 0.005,
"gpt-4o-realtime-preview-2024-12-17": 0.005,
"gpt-4o-realtime-preview-2024-10-01": 0.005,
"gpt-4o-realtime-preview-cached": 0.0025,
"gpt-4o-realtime-preview-2024-12-17-cached": 0.0025,
"gpt-4o-realtime-preview-2024-10-01-cached": 0.0025,
# GPT-4o-realtime-preview output
"gpt-4o-realtime-preview-completion": 0.02,
"gpt-4o-realtime-preview-2024-12-17-completion": 0.02,
"gpt-4o-realtime-preview-2024-10-01-completion": 0.02,
# GPT-4o-mini input
"gpt-4o-mini": 0.00015,
"gpt-4o-mini-cached": 0.000075,
"gpt-4o-mini-2024-07-18": 0.00015,
"gpt-4o-mini-2024-07-18-cached": 0.000075,
# GPT-4o-mini output
"gpt-4o-mini-completion": 0.0006,
"gpt-4o-mini-2024-07-18-completion": 0.0006,
# GPT-4o-mini-audio-preview input
"gpt-4o-mini-audio-preview": 0.00015,
"gpt-4o-mini-audio-preview-2024-12-17": 0.00015,
# GPT-4o-mini-audio-preview output
"gpt-4o-mini-audio-preview-completion": 0.0006,
"gpt-4o-mini-audio-preview-2024-12-17-completion": 0.0006,
# GPT-4o-mini-realtime-preview input
"gpt-4o-mini-realtime-preview": 0.0006,
"gpt-4o-mini-realtime-preview-2024-12-17": 0.0006,
"gpt-4o-mini-realtime-preview-cached": 0.0003,
"gpt-4o-mini-realtime-preview-2024-12-17-cached": 0.0003,
# GPT-4o-mini-realtime-preview output
"gpt-4o-mini-realtime-preview-completion": 0.0024,
"gpt-4o-mini-realtime-preview-2024-12-17-completion": 0.0024,
# GPT-4o-mini-search-preview input
"gpt-4o-mini-search-preview": 0.00015,
"gpt-4o-mini-search-preview-2025-03-11": 0.00015,
# GPT-4o-mini-search-preview output
"gpt-4o-mini-search-preview-completion": 0.0006,
"gpt-4o-mini-search-preview-2025-03-11-completion": 0.0006,
# GPT-4o-search-preview input
"gpt-4o-search-preview": 0.0025,
"gpt-4o-search-preview-2025-03-11": 0.0025,
# GPT-4o-search-preview output
"gpt-4o-search-preview-completion": 0.01,
"gpt-4o-search-preview-2025-03-11-completion": 0.01,
# Computer-use-preview input
"computer-use-preview": 0.003,
"computer-use-preview-2025-03-11": 0.003,
# Computer-use-preview output
"computer-use-preview-completion": 0.012,
"computer-use-preview-2025-03-11-completion": 0.012,
# GPT-4 input # GPT-4 input
"gpt-4": 0.03, "gpt-4": 0.03,
"gpt-4-0314": 0.03, "gpt-4-0314": 0.03,
@ -219,6 +324,7 @@ def standardize_model_name(
or model_name.startswith("gpt-35") or model_name.startswith("gpt-35")
or model_name.startswith("o1-") or model_name.startswith("o1-")
or model_name.startswith("o3-") or model_name.startswith("o3-")
or model_name.startswith("o4-")
or ("finetuned" in model_name and "legacy" not in model_name) or ("finetuned" in model_name and "legacy" not in model_name)
): ):
return model_name + "-completion" return model_name + "-completion"
@ -226,8 +332,10 @@ def standardize_model_name(
token_type == TokenType.PROMPT_CACHED token_type == TokenType.PROMPT_CACHED
and ( and (
model_name.startswith("gpt-4o") model_name.startswith("gpt-4o")
or model_name.startswith("gpt-4.1")
or model_name.startswith("o1") or model_name.startswith("o1")
or model_name.startswith("o3") or model_name.startswith("o3")
or model_name.startswith("o4")
) )
and not (model_name.startswith("gpt-4o-2024-05-13")) and not (model_name.startswith("gpt-4o-2024-05-13"))
): ):

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
from langchain_core.documents import Document from langchain_core.documents import Document
@ -31,6 +31,7 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
wallet_password: Optional[str] = None, wallet_password: Optional[str] = None,
connection_string: Optional[str] = None, connection_string: Optional[str] = None,
metadata: Optional[List[str]] = None, metadata: Optional[List[str]] = None,
parameters: Optional[Union[list, tuple, dict]] = None,
): ):
""" """
init method init method
@ -44,6 +45,7 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
:param wallet_password: password of wallet :param wallet_password: password of wallet
:param connection_string: connection string to connect to adb instance :param connection_string: connection string to connect to adb instance
:param metadata: metadata used in document :param metadata: metadata used in document
:param parameters: bind variable to use in query
""" """
# Mandatory required arguments. # Mandatory required arguments.
self.query = query self.query = query
@ -67,6 +69,9 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
# metadata column # metadata column
self.metadata = metadata self.metadata = metadata
# parameters, e.g bind variable
self.parameters = parameters
# dsn # dsn
self.dsn: Optional[str] self.dsn: Optional[str]
self._set_dsn() self._set_dsn()
@ -96,7 +101,10 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
cursor = connection.cursor() cursor = connection.cursor()
if self.schema: if self.schema:
cursor.execute(f"alter session set current_schema={self.schema}") cursor.execute(f"alter session set current_schema={self.schema}")
cursor.execute(self.query) if self.parameters:
cursor.execute(self.query, self.parameters)
else:
cursor.execute(self.query)
columns = [col[0] for col in cursor.description] columns = [col[0] for col in cursor.description]
data = cursor.fetchall() data = cursor.fetchall()
data = [ data = [

View File

@ -668,7 +668,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# NOTE: to keep things simple, we assume the list may contain texts longer # NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function. # than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment) engine = cast(str, self.deployment)
return self._get_len_safe_embeddings(texts, engine=engine) return self._get_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size
)
async def aembed_documents( async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0 self, texts: List[str], chunk_size: Optional[int] = 0
@ -686,7 +688,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# NOTE: to keep things simple, we assume the list may contain texts longer # NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function. # than the maximum context and use length-safe embedding function.
engine = cast(str, self.deployment) engine = cast(str, self.deployment)
return await self._aget_len_safe_embeddings(texts, engine=engine) return self._get_len_safe_embeddings(
texts, engine=engine, chunk_size=chunk_size
)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text. """Call out to OpenAI's embedding endpoint for embedding query text.

View File

@ -1,7 +1,12 @@
import os
from unittest.mock import patch
import pytest import pytest
from langchain_community.embeddings.openai import OpenAIEmbeddings from langchain_community.embeddings.openai import OpenAIEmbeddings
os.environ["OPENAI_API_KEY"] = "foo"
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None: def test_openai_invalid_model_kwargs() -> None:
@ -14,3 +19,20 @@ def test_openai_incorrect_field() -> None:
with pytest.warns(match="not default parameter"): with pytest.warns(match="not default parameter"):
llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo") # type: ignore[call-arg] llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo") # type: ignore[call-arg]
assert llm.model_kwargs == {"foo": "bar"} assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("openai")
def test_embed_documents_with_custom_chunk_size() -> None:
embeddings = OpenAIEmbeddings(chunk_size=2)
texts = ["text1", "text2", "text3", "text4"]
custom_chunk_size = 3
with patch.object(embeddings.client, "create") as mock_create:
mock_create.side_effect = [
{"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
{"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
]
embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)
mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)

View File

@ -2,6 +2,7 @@
import copy import copy
import json import json
import logging
from json import JSONDecodeError from json import JSONDecodeError
from typing import Annotated, Any, Optional from typing import Annotated, Any, Optional
@ -16,6 +17,8 @@ from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils.json import parse_partial_json from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel from langchain_core.utils.pydantic import TypeBaseModel
logger = logging.getLogger(__name__)
def parse_tool_call( def parse_tool_call(
raw_tool_call: dict[str, Any], raw_tool_call: dict[str, Any],
@ -250,6 +253,14 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
return parsed_result return parsed_result
# Common cause of ValidationError is truncated output due to max_tokens.
_MAX_TOKENS_ERROR = (
"Output parser received a `max_tokens` stop reason. "
"The output is likely incomplete—please increase `max_tokens` "
"or shorten your prompt."
)
class PydanticToolsParser(JsonOutputToolsParser): class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response.""" """Parse tools from OpenAI response."""
@ -296,6 +307,14 @@ class PydanticToolsParser(JsonOutputToolsParser):
except (ValidationError, ValueError): except (ValidationError, ValueError):
if partial: if partial:
continue continue
has_max_tokens_stop_reason = any(
generation.message.response_metadata.get("stop_reason")
== "max_tokens"
for generation in result
if isinstance(generation, ChatGeneration)
)
if has_max_tokens_stop_reason:
logger.exception(_MAX_TOKENS_ERROR)
raise raise
if self.first_tool_only: if self.first_tool_only:
return pydantic_objects[0] if pydantic_objects else None return pydantic_objects[0] if pydantic_objects else None

View File

@ -108,7 +108,7 @@ class Node(NamedTuple):
id: str id: str
name: str name: str
data: Union[type[BaseModel], RunnableType] data: Union[type[BaseModel], RunnableType, None]
metadata: Optional[dict[str, Any]] metadata: Optional[dict[str, Any]]
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node: def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
@ -181,7 +181,7 @@ class MermaidDrawMethod(Enum):
API = "api" # Uses Mermaid.INK API to render the graph API = "api" # Uses Mermaid.INK API to render the graph
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str: def node_data_str(id: str, data: Union[type[BaseModel], RunnableType, None]) -> str:
"""Convert the data of a node to a string. """Convert the data of a node to a string.
Args: Args:
@ -193,7 +193,7 @@ def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
""" """
from langchain_core.runnables.base import Runnable from langchain_core.runnables.base import Runnable
if not is_uuid(id): if not is_uuid(id) or data is None:
return id return id
data_str = data.get_name() if isinstance(data, Runnable) else data.__name__ data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
return data_str if not data_str.startswith("Runnable") else data_str[8:] return data_str if not data_str.startswith("Runnable") else data_str[8:]
@ -215,8 +215,10 @@ def node_data_json(
from langchain_core.load.serializable import to_json_not_implemented from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
if isinstance(node.data, RunnableSerializable): if node.data is None:
json: dict[str, Any] = { json: dict[str, Any] = {}
elif isinstance(node.data, RunnableSerializable):
json = {
"type": "runnable", "type": "runnable",
"data": { "data": {
"id": node.data.lc_id(), "id": node.data.lc_id(),
@ -317,7 +319,7 @@ class Graph:
def add_node( def add_node(
self, self,
data: Union[type[BaseModel], RunnableType], data: Union[type[BaseModel], RunnableType, None],
id: Optional[str] = None, id: Optional[str] = None,
*, *,
metadata: Optional[dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,

View File

@ -2,7 +2,7 @@ from collections.abc import AsyncIterator, Iterator
from typing import Any from typing import Any
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, ValidationError
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
@ -635,3 +635,24 @@ def test_parse_with_different_pydantic_1_proper() -> None:
forecast="Sunny", forecast="Sunny",
) )
] ]
def test_max_tokens_error(caplog: Any) -> None:
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
input = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "NameCollector",
"args": {"names": ["suz", "jerm"]},
}
],
response_metadata={"stop_reason": "max_tokens"},
)
with pytest.raises(ValidationError):
_ = parser.invoke(input)
assert any(
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
for record, msg in zip(caplog.records, caplog.messages)
)

View File

@ -82,8 +82,8 @@ class AnthropicTool(TypedDict):
"""Anthropic tool definition.""" """Anthropic tool definition."""
name: str name: str
description: str
input_schema: dict[str, Any] input_schema: dict[str, Any]
description: NotRequired[str]
cache_control: NotRequired[dict[str, str]] cache_control: NotRequired[dict[str, str]]
@ -1675,9 +1675,10 @@ def convert_to_anthropic_tool(
oai_formatted = convert_to_openai_tool(tool)["function"] oai_formatted = convert_to_openai_tool(tool)["function"]
anthropic_formatted = AnthropicTool( anthropic_formatted = AnthropicTool(
name=oai_formatted["name"], name=oai_formatted["name"],
description=oai_formatted["description"],
input_schema=oai_formatted["parameters"], input_schema=oai_formatted["parameters"],
) )
if "description" in oai_formatted:
anthropic_formatted["description"] = oai_formatted["description"]
return anthropic_formatted return anthropic_formatted

View File

@ -931,3 +931,12 @@ def test_anthropic_bind_tools_tool_choice() -> None:
assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == { assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == {
"type": "any" "type": "any"
} }
def test_optional_description() -> None:
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
class SampleModel(BaseModel):
sample_field: str
_ = llm.with_structured_output(SampleModel.model_json_schema())