mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
Merge branch 'master' into pgvectorstore-docs
This commit is contained in:
commit
ac7e73a531
@ -11,6 +11,7 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import toml
|
||||
@ -104,7 +105,7 @@ def skip_private_members(app, what, name, obj, skip, options):
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "🦜🔗 LangChain"
|
||||
copyright = "2023, LangChain Inc"
|
||||
copyright = f"{datetime.now().year}, LangChain Inc"
|
||||
author = "LangChain, Inc"
|
||||
|
||||
html_favicon = "_static/img/brand/favicon.png"
|
||||
|
@ -36,10 +36,7 @@
|
||||
"pip install oracledb"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -51,10 +48,7 @@
|
||||
"from settings import s"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -97,16 +91,14 @@
|
||||
"doc_2 = doc_loader_2.load()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"With TLS authentication, wallet_location and wallet_password are not required."
|
||||
"With TLS authentication, wallet_location and wallet_password are not required.\n",
|
||||
"Bind variable option is provided by argument \"parameters\"."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
@ -117,6 +109,8 @@
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SQL_QUERY = \"select channel_id, channel_desc from sh.channels where channel_desc = :1 fetch first 5 rows only\"\n",
|
||||
"\n",
|
||||
"doc_loader_3 = OracleAutonomousDatabaseLoader(\n",
|
||||
" query=SQL_QUERY,\n",
|
||||
" user=s.USERNAME,\n",
|
||||
@ -124,6 +118,7 @@
|
||||
" schema=s.SCHEMA,\n",
|
||||
" config_dir=s.CONFIG_DIR,\n",
|
||||
" tns_name=s.TNS_NAME,\n",
|
||||
" parameters=[\"Direct Sales\"],\n",
|
||||
")\n",
|
||||
"doc_3 = doc_loader_3.load()\n",
|
||||
"\n",
|
||||
@ -133,6 +128,7 @@
|
||||
" password=s.PASSWORD,\n",
|
||||
" schema=s.SCHEMA,\n",
|
||||
" connection_string=s.CONNECTION_STRING,\n",
|
||||
" parameters=[\"Direct Sales\"],\n",
|
||||
")\n",
|
||||
"doc_4 = doc_loader_4.load()"
|
||||
],
|
||||
|
@ -183,7 +183,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Initalize simple_datasource_qa for querying Tableau Datasources through VDS\n",
|
||||
"# Initialize simple_datasource_qa for querying Tableau Datasources through VDS\n",
|
||||
"analyze_datasource = initialize_simple_datasource_qa(\n",
|
||||
" domain=tableau_server,\n",
|
||||
" site=tableau_site,\n",
|
||||
|
@ -10,6 +10,38 @@ from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
|
||||
MODEL_COST_PER_1K_TOKENS = {
|
||||
# GPT-4.1 input
|
||||
"gpt-4.1": 0.002,
|
||||
"gpt-4.1-2025-04-14": 0.002,
|
||||
"gpt-4.1-cached": 0.0005,
|
||||
"gpt-4.1-2025-04-14-cached": 0.0005,
|
||||
# GPT-4.1 output
|
||||
"gpt-4.1-completion": 0.008,
|
||||
"gpt-4.1-2025-04-14-completion": 0.008,
|
||||
# GPT-4.1-mini input
|
||||
"gpt-4.1-mini": 0.0004,
|
||||
"gpt-4.1-mini-2025-04-14": 0.0004,
|
||||
"gpt-4.1-mini-cached": 0.0001,
|
||||
"gpt-4.1-mini-2025-04-14-cached": 0.0001,
|
||||
# GPT-4.1-mini output
|
||||
"gpt-4.1-mini-completion": 0.0016,
|
||||
"gpt-4.1-mini-2025-04-14-completion": 0.0016,
|
||||
# GPT-4.1-nano input
|
||||
"gpt-4.1-nano": 0.0001,
|
||||
"gpt-4.1-nano-2025-04-14": 0.0001,
|
||||
"gpt-4.1-nano-cached": 0.000025,
|
||||
"gpt-4.1-nano-2025-04-14-cached": 0.000025,
|
||||
# GPT-4.1-nano output
|
||||
"gpt-4.1-nano-completion": 0.0004,
|
||||
"gpt-4.1-nano-2025-04-14-completion": 0.0004,
|
||||
# GPT-4.5-preview input
|
||||
"gpt-4.5-preview": 0.075,
|
||||
"gpt-4.5-preview-2025-02-27": 0.075,
|
||||
"gpt-4.5-preview-cached": 0.0375,
|
||||
"gpt-4.5-preview-2025-02-27-cached": 0.0375,
|
||||
# GPT-4.5-preview output
|
||||
"gpt-4.5-preview-completion": 0.15,
|
||||
"gpt-4.5-preview-2025-02-27-completion": 0.15,
|
||||
# OpenAI o1 input
|
||||
"o1": 0.015,
|
||||
"o1-2024-12-17": 0.015,
|
||||
@ -18,6 +50,28 @@ MODEL_COST_PER_1K_TOKENS = {
|
||||
# OpenAI o1 output
|
||||
"o1-completion": 0.06,
|
||||
"o1-2024-12-17-completion": 0.06,
|
||||
# OpenAI o1-pro input
|
||||
"o1-pro": 0.15,
|
||||
"o1-pro-2025-03-19": 0.15,
|
||||
# OpenAI o1-pro output
|
||||
"o1-pro-completion": 0.6,
|
||||
"o1-pro-2025-03-19-completion": 0.6,
|
||||
# OpenAI o3 input
|
||||
"o3": 0.01,
|
||||
"o3-2025-04-16": 0.01,
|
||||
"o3-cached": 0.0025,
|
||||
"o3-2025-04-16-cached": 0.0025,
|
||||
# OpenAI o3 output
|
||||
"o3-completion": 0.04,
|
||||
"o3-2025-04-16-completion": 0.04,
|
||||
# OpenAI o4-mini input
|
||||
"o4-mini": 0.0011,
|
||||
"o4-mini-2025-04-16": 0.0011,
|
||||
"o4-mini-cached": 0.000275,
|
||||
"o4-mini-2025-04-16-cached": 0.000275,
|
||||
# OpenAI o4-mini output
|
||||
"o4-mini-completion": 0.0044,
|
||||
"o4-mini-2025-04-16-completion": 0.0044,
|
||||
# OpenAI o3-mini input
|
||||
"o3-mini": 0.0011,
|
||||
"o3-mini-2025-01-31": 0.0011,
|
||||
@ -26,6 +80,14 @@ MODEL_COST_PER_1K_TOKENS = {
|
||||
# OpenAI o3-mini output
|
||||
"o3-mini-completion": 0.0044,
|
||||
"o3-mini-2025-01-31-completion": 0.0044,
|
||||
# OpenAI o1-mini input (updated pricing)
|
||||
"o1-mini": 0.0011,
|
||||
"o1-mini-cached": 0.00055,
|
||||
"o1-mini-2024-09-12": 0.0011,
|
||||
"o1-mini-2024-09-12-cached": 0.00055,
|
||||
# OpenAI o1-mini output (updated pricing)
|
||||
"o1-mini-completion": 0.0044,
|
||||
"o1-mini-2024-09-12-completion": 0.0044,
|
||||
# OpenAI o1-preview input
|
||||
"o1-preview": 0.015,
|
||||
"o1-preview-cached": 0.0075,
|
||||
@ -34,22 +96,6 @@ MODEL_COST_PER_1K_TOKENS = {
|
||||
# OpenAI o1-preview output
|
||||
"o1-preview-completion": 0.06,
|
||||
"o1-preview-2024-09-12-completion": 0.06,
|
||||
# OpenAI o1-mini input
|
||||
"o1-mini": 0.003,
|
||||
"o1-mini-cached": 0.0015,
|
||||
"o1-mini-2024-09-12": 0.003,
|
||||
"o1-mini-2024-09-12-cached": 0.0015,
|
||||
# OpenAI o1-mini output
|
||||
"o1-mini-completion": 0.012,
|
||||
"o1-mini-2024-09-12-completion": 0.012,
|
||||
# GPT-4o-mini input
|
||||
"gpt-4o-mini": 0.00015,
|
||||
"gpt-4o-mini-cached": 0.000075,
|
||||
"gpt-4o-mini-2024-07-18": 0.00015,
|
||||
"gpt-4o-mini-2024-07-18-cached": 0.000075,
|
||||
# GPT-4o-mini output
|
||||
"gpt-4o-mini-completion": 0.0006,
|
||||
"gpt-4o-mini-2024-07-18-completion": 0.0006,
|
||||
# GPT-4o input
|
||||
"gpt-4o": 0.0025,
|
||||
"gpt-4o-cached": 0.00125,
|
||||
@ -63,6 +109,65 @@ MODEL_COST_PER_1K_TOKENS = {
|
||||
"gpt-4o-2024-05-13-completion": 0.015,
|
||||
"gpt-4o-2024-08-06-completion": 0.01,
|
||||
"gpt-4o-2024-11-20-completion": 0.01,
|
||||
# GPT-4o-audio-preview input
|
||||
"gpt-4o-audio-preview": 0.0025,
|
||||
"gpt-4o-audio-preview-2024-12-17": 0.0025,
|
||||
"gpt-4o-audio-preview-2024-10-01": 0.0025,
|
||||
# GPT-4o-audio-preview output
|
||||
"gpt-4o-audio-preview-completion": 0.01,
|
||||
"gpt-4o-audio-preview-2024-12-17-completion": 0.01,
|
||||
"gpt-4o-audio-preview-2024-10-01-completion": 0.01,
|
||||
# GPT-4o-realtime-preview input
|
||||
"gpt-4o-realtime-preview": 0.005,
|
||||
"gpt-4o-realtime-preview-2024-12-17": 0.005,
|
||||
"gpt-4o-realtime-preview-2024-10-01": 0.005,
|
||||
"gpt-4o-realtime-preview-cached": 0.0025,
|
||||
"gpt-4o-realtime-preview-2024-12-17-cached": 0.0025,
|
||||
"gpt-4o-realtime-preview-2024-10-01-cached": 0.0025,
|
||||
# GPT-4o-realtime-preview output
|
||||
"gpt-4o-realtime-preview-completion": 0.02,
|
||||
"gpt-4o-realtime-preview-2024-12-17-completion": 0.02,
|
||||
"gpt-4o-realtime-preview-2024-10-01-completion": 0.02,
|
||||
# GPT-4o-mini input
|
||||
"gpt-4o-mini": 0.00015,
|
||||
"gpt-4o-mini-cached": 0.000075,
|
||||
"gpt-4o-mini-2024-07-18": 0.00015,
|
||||
"gpt-4o-mini-2024-07-18-cached": 0.000075,
|
||||
# GPT-4o-mini output
|
||||
"gpt-4o-mini-completion": 0.0006,
|
||||
"gpt-4o-mini-2024-07-18-completion": 0.0006,
|
||||
# GPT-4o-mini-audio-preview input
|
||||
"gpt-4o-mini-audio-preview": 0.00015,
|
||||
"gpt-4o-mini-audio-preview-2024-12-17": 0.00015,
|
||||
# GPT-4o-mini-audio-preview output
|
||||
"gpt-4o-mini-audio-preview-completion": 0.0006,
|
||||
"gpt-4o-mini-audio-preview-2024-12-17-completion": 0.0006,
|
||||
# GPT-4o-mini-realtime-preview input
|
||||
"gpt-4o-mini-realtime-preview": 0.0006,
|
||||
"gpt-4o-mini-realtime-preview-2024-12-17": 0.0006,
|
||||
"gpt-4o-mini-realtime-preview-cached": 0.0003,
|
||||
"gpt-4o-mini-realtime-preview-2024-12-17-cached": 0.0003,
|
||||
# GPT-4o-mini-realtime-preview output
|
||||
"gpt-4o-mini-realtime-preview-completion": 0.0024,
|
||||
"gpt-4o-mini-realtime-preview-2024-12-17-completion": 0.0024,
|
||||
# GPT-4o-mini-search-preview input
|
||||
"gpt-4o-mini-search-preview": 0.00015,
|
||||
"gpt-4o-mini-search-preview-2025-03-11": 0.00015,
|
||||
# GPT-4o-mini-search-preview output
|
||||
"gpt-4o-mini-search-preview-completion": 0.0006,
|
||||
"gpt-4o-mini-search-preview-2025-03-11-completion": 0.0006,
|
||||
# GPT-4o-search-preview input
|
||||
"gpt-4o-search-preview": 0.0025,
|
||||
"gpt-4o-search-preview-2025-03-11": 0.0025,
|
||||
# GPT-4o-search-preview output
|
||||
"gpt-4o-search-preview-completion": 0.01,
|
||||
"gpt-4o-search-preview-2025-03-11-completion": 0.01,
|
||||
# Computer-use-preview input
|
||||
"computer-use-preview": 0.003,
|
||||
"computer-use-preview-2025-03-11": 0.003,
|
||||
# Computer-use-preview output
|
||||
"computer-use-preview-completion": 0.012,
|
||||
"computer-use-preview-2025-03-11-completion": 0.012,
|
||||
# GPT-4 input
|
||||
"gpt-4": 0.03,
|
||||
"gpt-4-0314": 0.03,
|
||||
@ -219,6 +324,7 @@ def standardize_model_name(
|
||||
or model_name.startswith("gpt-35")
|
||||
or model_name.startswith("o1-")
|
||||
or model_name.startswith("o3-")
|
||||
or model_name.startswith("o4-")
|
||||
or ("finetuned" in model_name and "legacy" not in model_name)
|
||||
):
|
||||
return model_name + "-completion"
|
||||
@ -226,8 +332,10 @@ def standardize_model_name(
|
||||
token_type == TokenType.PROMPT_CACHED
|
||||
and (
|
||||
model_name.startswith("gpt-4o")
|
||||
or model_name.startswith("gpt-4.1")
|
||||
or model_name.startswith("o1")
|
||||
or model_name.startswith("o3")
|
||||
or model_name.startswith("o4")
|
||||
)
|
||||
and not (model_name.startswith("gpt-4o-2024-05-13"))
|
||||
):
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@ -31,6 +31,7 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
|
||||
wallet_password: Optional[str] = None,
|
||||
connection_string: Optional[str] = None,
|
||||
metadata: Optional[List[str]] = None,
|
||||
parameters: Optional[Union[list, tuple, dict]] = None,
|
||||
):
|
||||
"""
|
||||
init method
|
||||
@ -44,6 +45,7 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
|
||||
:param wallet_password: password of wallet
|
||||
:param connection_string: connection string to connect to adb instance
|
||||
:param metadata: metadata used in document
|
||||
:param parameters: bind variable to use in query
|
||||
"""
|
||||
# Mandatory required arguments.
|
||||
self.query = query
|
||||
@ -67,6 +69,9 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
|
||||
# metadata column
|
||||
self.metadata = metadata
|
||||
|
||||
# parameters, e.g bind variable
|
||||
self.parameters = parameters
|
||||
|
||||
# dsn
|
||||
self.dsn: Optional[str]
|
||||
self._set_dsn()
|
||||
@ -96,7 +101,10 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
|
||||
cursor = connection.cursor()
|
||||
if self.schema:
|
||||
cursor.execute(f"alter session set current_schema={self.schema}")
|
||||
cursor.execute(self.query)
|
||||
if self.parameters:
|
||||
cursor.execute(self.query, self.parameters)
|
||||
else:
|
||||
cursor.execute(self.query)
|
||||
columns = [col[0] for col in cursor.description]
|
||||
data = cursor.fetchall()
|
||||
data = [
|
||||
|
@ -668,7 +668,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
engine = cast(str, self.deployment)
|
||||
return self._get_len_safe_embeddings(texts, engine=engine)
|
||||
return self._get_len_safe_embeddings(
|
||||
texts, engine=engine, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
async def aembed_documents(
|
||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||
@ -686,7 +688,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||
# than the maximum context and use length-safe embedding function.
|
||||
engine = cast(str, self.deployment)
|
||||
return await self._aget_len_safe_embeddings(texts, engine=engine)
|
||||
return self._get_len_safe_embeddings(
|
||||
texts, engine=engine, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||
|
@ -1,7 +1,12 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_openai_invalid_model_kwargs() -> None:
|
||||
@ -14,3 +19,20 @@ def test_openai_incorrect_field() -> None:
|
||||
with pytest.warns(match="not default parameter"):
|
||||
llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo") # type: ignore[call-arg]
|
||||
assert llm.model_kwargs == {"foo": "bar"}
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_embed_documents_with_custom_chunk_size() -> None:
|
||||
embeddings = OpenAIEmbeddings(chunk_size=2)
|
||||
texts = ["text1", "text2", "text3", "text4"]
|
||||
custom_chunk_size = 3
|
||||
|
||||
with patch.object(embeddings.client, "create") as mock_create:
|
||||
mock_create.side_effect = [
|
||||
{"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
|
||||
{"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
|
||||
]
|
||||
|
||||
embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
|
||||
mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)
|
||||
mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
@ -16,6 +17,8 @@ from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
from langchain_core.utils.pydantic import TypeBaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_tool_call(
|
||||
raw_tool_call: dict[str, Any],
|
||||
@ -250,6 +253,14 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
return parsed_result
|
||||
|
||||
|
||||
# Common cause of ValidationError is truncated output due to max_tokens.
|
||||
_MAX_TOKENS_ERROR = (
|
||||
"Output parser received a `max_tokens` stop reason. "
|
||||
"The output is likely incomplete—please increase `max_tokens` "
|
||||
"or shorten your prompt."
|
||||
)
|
||||
|
||||
|
||||
class PydanticToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
@ -296,6 +307,14 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
||||
except (ValidationError, ValueError):
|
||||
if partial:
|
||||
continue
|
||||
has_max_tokens_stop_reason = any(
|
||||
generation.message.response_metadata.get("stop_reason")
|
||||
== "max_tokens"
|
||||
for generation in result
|
||||
if isinstance(generation, ChatGeneration)
|
||||
)
|
||||
if has_max_tokens_stop_reason:
|
||||
logger.exception(_MAX_TOKENS_ERROR)
|
||||
raise
|
||||
if self.first_tool_only:
|
||||
return pydantic_objects[0] if pydantic_objects else None
|
||||
|
@ -108,7 +108,7 @@ class Node(NamedTuple):
|
||||
|
||||
id: str
|
||||
name: str
|
||||
data: Union[type[BaseModel], RunnableType]
|
||||
data: Union[type[BaseModel], RunnableType, None]
|
||||
metadata: Optional[dict[str, Any]]
|
||||
|
||||
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
||||
@ -181,7 +181,7 @@ class MermaidDrawMethod(Enum):
|
||||
API = "api" # Uses Mermaid.INK API to render the graph
|
||||
|
||||
|
||||
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
|
||||
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType, None]) -> str:
|
||||
"""Convert the data of a node to a string.
|
||||
|
||||
Args:
|
||||
@ -193,7 +193,7 @@ def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
|
||||
"""
|
||||
from langchain_core.runnables.base import Runnable
|
||||
|
||||
if not is_uuid(id):
|
||||
if not is_uuid(id) or data is None:
|
||||
return id
|
||||
data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
|
||||
return data_str if not data_str.startswith("Runnable") else data_str[8:]
|
||||
@ -215,8 +215,10 @@ def node_data_json(
|
||||
from langchain_core.load.serializable import to_json_not_implemented
|
||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||
|
||||
if isinstance(node.data, RunnableSerializable):
|
||||
json: dict[str, Any] = {
|
||||
if node.data is None:
|
||||
json: dict[str, Any] = {}
|
||||
elif isinstance(node.data, RunnableSerializable):
|
||||
json = {
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": node.data.lc_id(),
|
||||
@ -317,7 +319,7 @@ class Graph:
|
||||
|
||||
def add_node(
|
||||
self,
|
||||
data: Union[type[BaseModel], RunnableType],
|
||||
data: Union[type[BaseModel], RunnableType, None],
|
||||
id: Optional[str] = None,
|
||||
*,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
|
@ -2,7 +2,7 @@ from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@ -635,3 +635,24 @@ def test_parse_with_different_pydantic_1_proper() -> None:
|
||||
forecast="Sunny",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_max_tokens_error(caplog: Any) -> None:
|
||||
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
||||
input = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_OwL7f5PE",
|
||||
"name": "NameCollector",
|
||||
"args": {"names": ["suz", "jerm"]},
|
||||
}
|
||||
],
|
||||
response_metadata={"stop_reason": "max_tokens"},
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
_ = parser.invoke(input)
|
||||
assert any(
|
||||
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
|
||||
for record, msg in zip(caplog.records, caplog.messages)
|
||||
)
|
||||
|
@ -82,8 +82,8 @@ class AnthropicTool(TypedDict):
|
||||
"""Anthropic tool definition."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
description: NotRequired[str]
|
||||
cache_control: NotRequired[dict[str, str]]
|
||||
|
||||
|
||||
@ -1675,9 +1675,10 @@ def convert_to_anthropic_tool(
|
||||
oai_formatted = convert_to_openai_tool(tool)["function"]
|
||||
anthropic_formatted = AnthropicTool(
|
||||
name=oai_formatted["name"],
|
||||
description=oai_formatted["description"],
|
||||
input_schema=oai_formatted["parameters"],
|
||||
)
|
||||
if "description" in oai_formatted:
|
||||
anthropic_formatted["description"] = oai_formatted["description"]
|
||||
return anthropic_formatted
|
||||
|
||||
|
||||
|
@ -931,3 +931,12 @@ def test_anthropic_bind_tools_tool_choice() -> None:
|
||||
assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == {
|
||||
"type": "any"
|
||||
}
|
||||
|
||||
|
||||
def test_optional_description() -> None:
|
||||
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
|
||||
|
||||
class SampleModel(BaseModel):
|
||||
sample_field: str
|
||||
|
||||
_ = llm.with_structured_output(SampleModel.model_json_schema())
|
||||
|
Loading…
Reference in New Issue
Block a user