Merge branch 'master' into pprados/06-pdfplumber

This commit is contained in:
Philippe PRADOS 2025-03-03 16:28:17 +01:00 committed by GitHub
commit bd3a24f2d1
23 changed files with 270 additions and 92 deletions

.github/CODEOWNERS vendored
View File

@ -1,2 +1,2 @@
/.github/ @efriis @baskaryan @ccurme
/libs/packages.yml @efriis
/.github/ @baskaryan @ccurme
/libs/packages.yml @ccurme

View File

@ -26,4 +26,4 @@ Additional guidelines:
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.
If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
If no one reviews your PR within a few days, please @-mention one of baskaryan, eyurtsev, ccurme, vbarda, hwchase17.

View File

@ -328,7 +328,7 @@ html[data-theme=dark] .MathJax_SVG * {
}
.bd-sidebar-primary {
width: 22%; /* Adjust this value to your preference */
width: max-content; /* Adjust this value to your preference */
line-height: 1.4;
}

View File

@ -35,5 +35,5 @@ Please reference our [Review Process](review_process.mdx).
### I think my PR was closed in a way that didn't follow the review process. What should I do?
Tag `@efriis` in the PR comments referencing the portion of the review
Tag `@ccurme` in the PR comments referencing the portion of the review
process that you believe was not followed. We'll take a look!

View File

@ -1 +1,2 @@
httpx
httpx
grpcio

View File

@ -30,6 +30,7 @@ class AscendEmbeddings(Embeddings, BaseModel):
document_instruction: str = ""
use_fp16: bool = True
pooling_method: Optional[str] = "cls"
batch_size: int = 32
model: Any
tokenizer: Any
@ -119,7 +120,18 @@ class AscendEmbeddings(Embeddings, BaseModel):
)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self.encode([self.document_instruction + text for text in texts])
try:
import numpy as np
except ImportError as e:
raise ImportError(
"Unable to import numpy, please install with `pip install -U numpy`."
) from e
embedding_list = []
for i in range(0, len(texts), self.batch_size):
texts_ = texts[i : i + self.batch_size]
emb = self.encode([self.document_instruction + text for text in texts_])
embedding_list.append(emb)
return np.concatenate(embedding_list)
def embed_query(self, text: str) -> List[float]:
return self.encode([self.query_instruction + text])[0]
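
A minimal usage sketch of the new batching path; the `model_path` value and the input texts are placeholders, and an Ascend model directory is assumed to be available:

from langchain_community.embeddings import AscendEmbeddings

embedder = AscendEmbeddings(model_path="/path/to/ascend/model", batch_size=32)
texts = [f"document {i}" for i in range(100)]
# embed_documents now encodes the inputs in slices of batch_size and
# concatenates the per-batch results with numpy instead of encoding
# everything in a single call.
vectors = embedder.embed_documents(texts)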

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import json
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Generator,
Iterator,
@ -12,7 +14,12 @@ from typing import (
Union,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
@ -126,6 +133,7 @@ class Xinference(LLM):
self,
server_url: Optional[str] = None,
model_uid: Optional[str] = None,
api_key: Optional[str] = None,
**model_kwargs: Any,
):
try:
@ -155,7 +163,13 @@ class Xinference(LLM):
if self.model_uid is None:
raise ValueError("Please provide the model UID")
self.client = RESTfulClient(server_url)
self._headers: Dict[str, str] = {}
self._cluster_authed = False
self._check_cluster_authenticated()
if api_key is not None and self._cluster_authed:
self._headers["Authorization"] = f"Bearer {api_key}"
self.client = RESTfulClient(server_url, api_key)
@property
def _llm_type(self) -> str:
@ -171,6 +185,20 @@ class Xinference(LLM):
**{"model_kwargs": self.model_kwargs},
}
def _check_cluster_authenticated(self) -> None:
url = f"{self.server_url}/v1/cluster/auth"
response = requests.get(url)
if response.status_code == 404:
self._cluster_authed = False
else:
if response.status_code != 200:
raise RuntimeError(
f"Failed to get cluster information, "
f"detail: {response.json()['detail']}"
)
response_data = response.json()
self._cluster_authed = bool(response_data["auth"])
def _call(
self,
prompt: str,
@ -305,3 +333,61 @@ class Xinference(LLM):
return GenerationChunk(text=token)
else:
raise TypeError("stream_response type error!")
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
generate_config = kwargs.get("generate_config", {})
generate_config = {**self.model_kwargs, **generate_config}
if stop:
generate_config["stop"] = stop
async for stream_resp in self._acreate_generate_stream(prompt, generate_config):
if stream_resp:
chunk = self._stream_response_to_generation_chunk(stream_resp)
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk
async def _acreate_generate_stream(
self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
) -> AsyncIterator[str]:
request_body: Dict[str, Any] = {"model": self.model_uid, "prompt": prompt}
if generate_config is not None:
for key, value in generate_config.items():
request_body[key] = value
stream = bool(generate_config and generate_config.get("stream"))
async with aiohttp.ClientSession() as session:
async with session.post(
url=f"{self.server_url}/v1/completions",
json=request_body,
) as response:
if response.status != 200:
if response.status == 404:
raise FileNotFoundError(
"astream call failed with status code 404."
)
else:
optional_detail = await response.text()
raise ValueError(
f"astream call failed with status code {response.status}."
f" Details: {optional_detail}"
)
async for line in response.content:
if not stream:
yield json.loads(line)
else:
json_str = line.decode("utf-8")
if line.startswith(b"data:"):
json_str = json_str[len(b"data:") :].strip()
if not json_str:
continue
yield json.loads(json_str)
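
A hedged usage sketch of the new `api_key` parameter and the async streaming path; the server URL, model UID and key are placeholders:

import asyncio
from langchain_community.llms import Xinference

llm = Xinference(
    server_url="http://localhost:9997",  # placeholder endpoint
    model_uid="my-model-uid",            # placeholder UID
    api_key="sk-...",                    # only attached when the cluster reports auth enabled
)

async def main() -> None:
    # astream drives the new _astream/_acreate_generate_stream aiohttp path.
    async for chunk in llm.astream("Tell me a joke"):
        print(chunk, end="", flush=True)

asyncio.run(main())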

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, TextIO, cast
from langchain_core.callbacks import BaseCallbackHandler
@ -30,7 +31,7 @@ class FileCallbackHandler(BaseCallbackHandler):
mode: The mode to open the file in. Defaults to "a".
color: The color to use for the text. Defaults to None.
"""
self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) # noqa: SIM115
self.file = cast(TextIO, Path(filename).open(mode, encoding="utf-8")) # noqa: SIM115
self.color = color
def __del__(self) -> None:
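
Caller-facing behavior is unchanged; a small sketch, assuming log.txt is writable:

from langchain_core.callbacks import FileCallbackHandler

# The log file is now opened via pathlib.Path.open rather than builtins.open;
# __init__ still accepts a filename string, a mode and an optional color.
handler = FileCallbackHandler("log.txt", mode="a")
handler.file.write("run started\n")
handler.file.close()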

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import contextlib
import mimetypes
from io import BufferedReader, BytesIO
from pathlib import PurePath
from pathlib import Path, PurePath
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from pydantic import ConfigDict, Field, field_validator, model_validator
@ -151,8 +151,7 @@ class Blob(BaseMedia):
def as_string(self) -> str:
"""Read data as a string."""
if self.data is None and self.path:
with open(str(self.path), encoding=self.encoding) as f:
return f.read()
return Path(self.path).read_text(encoding=self.encoding)
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
elif isinstance(self.data, str):
@ -168,8 +167,7 @@ class Blob(BaseMedia):
elif isinstance(self.data, str):
return self.data.encode(self.encoding)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
return f.read()
return Path(self.path).read_bytes()
else:
msg = f"Unable to get bytes for blob {self}"
raise ValueError(msg)
@ -180,7 +178,7 @@ class Blob(BaseMedia):
if isinstance(self.data, bytes):
yield BytesIO(self.data)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
with Path(self.path).open("rb") as f:
yield f
else:
msg = f"Unable to convert blob {self}"

View File

@ -1402,7 +1402,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
llm.save(file_path="path/llm.yaml")
"""
# Convert file to Path object.
save_path = Path(file_path) if isinstance(file_path, str) else file_path
save_path = Path(file_path)
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
@ -1411,10 +1411,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
with save_path.open("w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "w") as f:
with save_path.open("w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
msg = f"{save_path} must be json or yaml"

View File

@ -368,16 +368,16 @@ class BasePromptTemplate(
raise NotImplementedError(msg)
# Convert file to Path object.
save_path = Path(file_path) if isinstance(file_path, str) else file_path
save_path = Path(file_path)
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:
with save_path.open("w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "w") as f:
with save_path.open("w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
msg = f"{save_path} must be json or yaml"

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
TYPE_CHECKING,
Annotated,
@ -48,7 +49,6 @@ from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
class BaseMessagePromptTemplate(Serializable, ABC):
@ -599,8 +599,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
Returns:
A new instance of this class.
"""
with open(str(template_file)) as f:
template = f.read()
template = Path(template_file).read_text()
return cls.from_template(template, input_variables=input_variables, **kwargs)
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:

View File

@ -53,8 +53,7 @@ def _load_template(var_name: str, config: dict) -> dict:
template_path = Path(config.pop(f"{var_name}_path"))
# Load the template.
if template_path.suffix == ".txt":
with open(template_path) as f:
template = f.read()
template = template_path.read_text()
else:
raise ValueError
# Set the template variable to the extracted variable.
@ -67,10 +66,11 @@ def _load_examples(config: dict) -> dict:
if isinstance(config["examples"], list):
pass
elif isinstance(config["examples"], str):
with open(config["examples"]) as f:
if config["examples"].endswith(".json"):
path = Path(config["examples"])
with path.open() as f:
if path.suffix == ".json":
examples = json.load(f)
elif config["examples"].endswith((".yaml", ".yml")):
elif path.suffix in {".yaml", ".yml"}:
examples = yaml.safe_load(f)
else:
msg = "Invalid file format. Only json or yaml formats are supported."
@ -168,13 +168,13 @@ def _load_prompt_from_file(
) -> BasePromptTemplate:
"""Load prompt from file."""
# Convert file to a Path object.
file_path = Path(file) if isinstance(file, str) else file
file_path = Path(file)
# Load from either json or yaml.
if file_path.suffix == ".json":
with open(file_path, encoding=encoding) as f:
with file_path.open(encoding=encoding) as f:
config = json.load(f)
elif file_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, encoding=encoding) as f:
with file_path.open(encoding=encoding) as f:
config = yaml.safe_load(f)
else:
msg = f"Got unsupported file type {file_path.suffix}"

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import BaseModel, model_validator
@ -17,8 +18,6 @@ from langchain_core.prompts.string import (
)
if TYPE_CHECKING:
from pathlib import Path
from langchain_core.runnables.config import RunnableConfig
@ -238,8 +237,7 @@ class PromptTemplate(StringPromptTemplate):
Returns:
The prompt loaded from the file.
"""
with open(str(template_file), encoding=encoding) as f:
template = f.read()
template = Path(template_file).read_text(encoding=encoding)
if input_variables:
warnings.warn(
"`input_variables' is deprecated and ignored.",

View File

@ -1644,8 +1644,13 @@ class Runnable(Generic[Input, Output], ABC):
.. code-block:: python
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableLambda, Runnable
from datetime import datetime, timezone
import time
import asyncio
def format_t(timestamp: float) -> str:
return datetime.fromtimestamp(timestamp, tz=timezone.utc).isoformat()
async def test_runnable(time_to_sleep : int):
print(f"Runnable[{time_to_sleep}s]: starts at {format_t(time.time())}")
@ -1653,12 +1658,12 @@ class Runnable(Generic[Input, Output], ABC):
print(f"Runnable[{time_to_sleep}s]: ends at {format_t(time.time())}")
async def fn_start(run_obj : Runnable):
print(f"on start callback starts at {format_t(time.time())}
print(f"on start callback starts at {format_t(time.time())}")
await asyncio.sleep(3)
print(f"on start callback ends at {format_t(time.time())}")
async def fn_end(run_obj : Runnable):
print(f"on end callback starts at {format_t(time.time())}
print(f"on end callback starts at {format_t(time.time())}")
await asyncio.sleep(2)
print(f"on end callback ends at {format_t(time.time())}")
@ -1671,18 +1676,18 @@ class Runnable(Generic[Input, Output], ABC):
asyncio.run(concurrent_runs())
Result:
on start callback starts at 2024-05-16T14:20:29.637053+00:00
on start callback starts at 2024-05-16T14:20:29.637150+00:00
on start callback ends at 2024-05-16T14:20:32.638305+00:00
on start callback ends at 2024-05-16T14:20:32.638383+00:00
Runnable[3s]: starts at 2024-05-16T14:20:32.638849+00:00
Runnable[5s]: starts at 2024-05-16T14:20:32.638999+00:00
Runnable[3s]: ends at 2024-05-16T14:20:35.640016+00:00
on end callback starts at 2024-05-16T14:20:35.640534+00:00
Runnable[5s]: ends at 2024-05-16T14:20:37.640169+00:00
on end callback starts at 2024-05-16T14:20:37.640574+00:00
on end callback ends at 2024-05-16T14:20:37.640654+00:00
on end callback ends at 2024-05-16T14:20:39.641751+00:00
on start callback starts at 2025-03-01T07:05:22.875378+00:00
on start callback starts at 2025-03-01T07:05:22.875495+00:00
on start callback ends at 2025-03-01T07:05:25.878862+00:00
on start callback ends at 2025-03-01T07:05:25.878947+00:00
Runnable[2s]: starts at 2025-03-01T07:05:25.879392+00:00
Runnable[3s]: starts at 2025-03-01T07:05:25.879804+00:00
Runnable[2s]: ends at 2025-03-01T07:05:27.881998+00:00
on end callback starts at 2025-03-01T07:05:27.882360+00:00
Runnable[3s]: ends at 2025-03-01T07:05:28.881737+00:00
on end callback starts at 2025-03-01T07:05:28.882428+00:00
on end callback ends at 2025-03-01T07:05:29.883893+00:00
on end callback ends at 2025-03-01T07:05:30.884831+00:00
"""
from langchain_core.tracers.root_listeners import AsyncRootListenersTracer

View File

@ -2,6 +2,7 @@ import asyncio
import base64
import re
from dataclasses import asdict
from pathlib import Path
from typing import Literal, Optional
from langchain_core.runnables.graph import (
@ -290,13 +291,9 @@ async def _render_mermaid_using_pyppeteer(
img_bytes = await page.screenshot({"fullPage": False})
await browser.close()
def write_to_file(path: str, bytes: bytes) -> None:
with open(path, "wb") as file:
file.write(bytes)
if output_file_path is not None:
await asyncio.get_event_loop().run_in_executor(
None, write_to_file, output_file_path, img_bytes
None, Path(output_file_path).write_bytes, img_bytes
)
return img_bytes
@ -337,8 +334,7 @@ def _render_mermaid_using_api(
if response.status_code == 200:
img_bytes = response.content
if output_file_path is not None:
with open(output_file_path, "wb") as file:
file.write(response.content)
Path(output_file_path).write_bytes(response.content)
return img_bytes
else:
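
The rendering entry points are unchanged; a hedged sketch of the file-writing path that now goes through Path.write_bytes (API draw method, network access assumed):

from langchain_core.runnables import RunnableLambda

chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
# output_file_path is written via Path(output_file_path).write_bytes internally.
png_bytes = chain.get_graph().draw_mermaid_png(output_file_path="graph.png")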

View File

@ -77,7 +77,7 @@ target-version = "py39"
[tool.ruff.lint]
select = [ "ANN", "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TC", "TID", "TRY", "UP", "W", "YTT",]
select = [ "ANN", "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "PTH", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TC", "TID", "TRY", "UP", "W", "YTT",]
ignore = [ "ANN401", "COM812", "UP007", "S110", "S112", "TC001", "TC002", "TC003"]
flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
flake8-annotations.allow-star-arg-any = true
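
Adding `PTH` (flake8-use-pathlib) to the select list is what drives most of the open()-to-pathlib hunks in this commit; a small before/after sketch of the pattern the rule set flags:

from pathlib import Path

# Flagged (PTH123): builtin open() on a path string.
with open("data/config.json", encoding="utf-8") as f:
    raw = f.read()

# Preferred after this commit: route file access through pathlib.
raw = Path("data/config.json").read_text(encoding="utf-8")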

View File

@ -354,7 +354,7 @@ def test_prompt_from_file_with_partial_variables() -> None:
template = "This is a {foo} test {bar}."
partial_variables = {"bar": "baz"}
# when
with mock.patch("builtins.open", mock.mock_open(read_data=template)):
with mock.patch("pathlib.Path.open", mock.mock_open(read_data=template)):
prompt = PromptTemplate.from_file(
"mock_file_name", partial_variables=partial_variables
)

View File

@ -1,20 +1,17 @@
import concurrent.futures
import glob
import importlib
import subprocess
from pathlib import Path
def test_importable_all() -> None:
for path in glob.glob("../core/langchain_core/*"):
relative_path = Path(path).parts[-1]
if relative_path.endswith(".typed"):
continue
module_name = relative_path.split(".")[0]
module = importlib.import_module("langchain_core." + module_name)
all_ = getattr(module, "__all__", [])
for cls_ in all_:
getattr(module, cls_)
for path in Path("../core/langchain_core/").glob("*"):
module_name = path.stem
if not module_name.startswith(".") and path.suffix != ".typed":
module = importlib.import_module("langchain_core." + module_name)
all_ = getattr(module, "__all__", [])
for cls_ in all_:
getattr(module, cls_)
def try_to_import(module_name: str) -> tuple[int, str]:
@ -37,12 +34,10 @@ def test_importable_all_via_subprocess() -> None:
for one sequence of imports but not another.
"""
module_names = []
for path in glob.glob("../core/langchain_core/*"):
relative_path = Path(path).parts[-1]
if relative_path.endswith(".typed"):
continue
module_name = relative_path.split(".")[0]
module_names.append(module_name)
for path in Path("../core/langchain_core/").glob("*"):
module_name = path.stem
if not module_name.startswith(".") and path.suffix != ".typed":
module_names.append(module_name)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [

View File

@ -26,6 +26,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
@ -83,6 +84,15 @@ _message_type_lookups = {
}
class AnthropicTool(TypedDict):
"""Anthropic tool definition."""
name: str
description: str
input_schema: Dict[str, Any]
cache_control: NotRequired[Dict[str, str]]
def _format_image(image_url: str) -> Dict:
"""
Formats an image of format data:image/jpeg;base64,{b64_string}
@ -900,6 +910,8 @@ class ChatAnthropic(BaseChatModel):
llm_output = {
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
}
if "model" in llm_output and "model_name" not in llm_output:
llm_output["model_name"] = llm_output["model"]
if (
len(content) == 1
and content[0]["type"] == "text"
@ -952,6 +964,31 @@ class ChatAnthropic(BaseChatModel):
data = await self._async_client.messages.create(**payload)
return self._format_output(data, **kwargs)
def _get_llm_for_structured_output_when_thinking_is_enabled(
self,
schema: Union[Dict, type],
formatted_tool: AnthropicTool,
) -> Runnable[LanguageModelInput, BaseMessage]:
thinking_admonition = (
"Anthropic structured output relies on forced tool calling, "
"which is not supported when `thinking` is enabled. This method will raise "
"langchain_core.exceptions.OutputParserException if tool calls are not "
"generated. Consider disabling `thinking` or adjust your prompt to ensure "
"the tool is called."
)
warnings.warn(thinking_admonition)
llm = self.bind_tools(
[schema],
structured_output_format={"kwargs": {}, "schema": formatted_tool},
)
def _raise_if_no_tool_calls(message: AIMessage) -> AIMessage:
if not message.tool_calls:
raise OutputParserException(thinking_admonition)
return message
return llm | _raise_if_no_tool_calls
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
@ -1249,11 +1286,17 @@ class ChatAnthropic(BaseChatModel):
""" # noqa: E501
formatted_tool = convert_to_anthropic_tool(schema)
tool_name = formatted_tool["name"]
llm = self.bind_tools(
[schema],
tool_choice=tool_name,
structured_output_format={"kwargs": {}, "schema": formatted_tool},
)
if self.thinking is not None and self.thinking.get("type") == "enabled":
llm = self._get_llm_for_structured_output_when_thinking_is_enabled(
schema, formatted_tool
)
else:
llm = self.bind_tools(
[schema],
tool_choice=tool_name,
structured_output_format={"kwargs": {}, "schema": formatted_tool},
)
if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
@ -1356,15 +1399,6 @@ class ChatAnthropic(BaseChatModel):
return response.input_tokens
class AnthropicTool(TypedDict):
"""Anthropic tool definition."""
name: str
description: str
input_schema: Dict[str, Any]
cache_control: NotRequired[Dict[str, str]]
def convert_to_anthropic_tool(
tool: Union[Dict[str, Any], Type, Callable, BaseTool],
) -> AnthropicTool:
@ -1445,9 +1479,14 @@ def _make_message_chunk_from_anthropic_event(
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
if event.type == "message_start" and stream_usage:
usage_metadata = _create_usage_metadata(event.message.usage)
if hasattr(event.message, "model"):
response_metadata = {"model_name": event.message.model}
else:
response_metadata = {}
message_chunk = AIMessageChunk(
content="" if coerce_content_to_string else [],
usage_metadata=usage_metadata,
response_metadata=response_metadata,
)
elif (
event.type == "content_block_start"
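
A hedged sketch of the new fallback when `thinking` is enabled; the Pydantic schema is illustrative and mirrors the integration test further below:

import warnings
from pydantic import BaseModel
from langchain_anthropic import ChatAnthropic

class GenerateUsername(BaseModel):
    name: str
    hair_color: str

llm = ChatAnthropic(
    model="claude-3-7-sonnet-latest",
    max_tokens=5_000,
    thinking={"type": "enabled", "budget_tokens": 2_000},
)
# Forced tool choice is unsupported with thinking, so this warns up front and
# later raises OutputParserException if the model never emits a tool call.
with warnings.catch_warnings(record=True):
    structured_llm = llm.with_structured_output(GenerateUsername)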

View File

@ -6,7 +6,9 @@ from typing import List, Optional
import pytest
import requests
from anthropic import BadRequestError
from langchain_core.callbacks import CallbackManager
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@ -35,6 +37,7 @@ def test_stream() -> None:
full: Optional[BaseMessageChunk] = None
chunks_with_input_token_counts = 0
chunks_with_output_token_counts = 0
chunks_with_model_name = 0
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
full = token if full is None else full + token
@ -44,12 +47,14 @@ def test_stream() -> None:
chunks_with_input_token_counts += 1
elif token.usage_metadata.get("output_tokens"):
chunks_with_output_token_counts += 1
chunks_with_model_name += int("model_name" in token.response_metadata)
if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with input or output token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
assert chunks_with_model_name == 1
# check token usage is populated
assert isinstance(full, AIMessageChunk)
assert full.usage_metadata is not None
@ -62,6 +67,7 @@ def test_stream() -> None:
)
assert "stop_reason" in full.response_metadata
assert "stop_sequence" in full.response_metadata
assert "model_name" in full.response_metadata
async def test_astream() -> None:
@ -219,6 +225,7 @@ async def test_ainvoke() -> None:
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
assert "model_name" in result.response_metadata
def test_invoke() -> None:
@ -725,3 +732,39 @@ def test_redacted_thinking() -> None:
assert set(block.keys()) == {"type", "data", "index"}
assert block["data"] and isinstance(block["data"], str)
assert stream_has_reasoning
def test_structured_output_thinking_enabled() -> None:
llm = ChatAnthropic(
model="claude-3-7-sonnet-latest",
max_tokens=5_000,
thinking={"type": "enabled", "budget_tokens": 2_000},
)
with pytest.warns(match="structured output"):
structured_llm = llm.with_structured_output(GenerateUsername)
query = "Generate a username for Sally with green hair"
response = structured_llm.invoke(query)
assert isinstance(response, GenerateUsername)
with pytest.raises(OutputParserException):
structured_llm.invoke("Hello")
# Test streaming
for chunk in structured_llm.stream(query):
assert isinstance(chunk, GenerateUsername)
def test_structured_output_thinking_force_tool_use() -> None:
# Structured output currently relies on forced tool use, which is not supported
# when `thinking` is enabled. When this test fails, it means that the feature
# is supported and the workarounds in `with_structured_output` should be removed.
llm = ChatAnthropic(
model="claude-3-7-sonnet-latest",
max_tokens=5_000,
thinking={"type": "enabled", "budget_tokens": 2_000},
).bind_tools(
[GenerateUsername],
tool_choice="GenerateUsername",
)
with pytest.raises(BadRequestError):
llm.invoke("Generate a username for Sally with green hair")

View File

@ -579,7 +579,11 @@ class ChatMistralAI(BaseChatModel):
)
generations.append(gen)
llm_output = {"token_usage": token_usage, "model": self.model}
llm_output = {
"token_usage": token_usage,
"model_name": self.model,
"model": self.model, # Backwards compatability
}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
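
A quick check of the new metadata key, assuming Mistral credentials are configured in the environment:

from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(model="mistral-small-latest")
result = llm.invoke("Hello")
# "model_name" is the new key; "model" is still emitted for backwards compatibility.
print(result.response_metadata["model_name"])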

View File

@ -87,6 +87,7 @@ async def test_ainvoke() -> None:
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
assert "model_name" in result.response_metadata
def test_invoke() -> None: