Merge branch 'master' into pprados/06-pdfplumber

Philippe PRADOS 2025-03-03 16:28:17 +01:00 committed by GitHub
commit bd3a24f2d1
GPG Key ID: B5690EEEBB952194
23 changed files with 270 additions and 92 deletions

.github/CODEOWNERS

@@ -1,2 +1,2 @@
-/.github/ @efriis @baskaryan @ccurme
-/libs/packages.yml @efriis
+/.github/ @baskaryan @ccurme
+/libs/packages.yml @ccurme

@@ -26,4 +26,4 @@ Additional guidelines:
 - Changes should be backwards compatible.
 - If you are adding something to community, do not re-import it in langchain.
-If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
+If no one reviews your PR within a few days, please @-mention one of baskaryan, eyurtsev, ccurme, vbarda, hwchase17.

@@ -328,7 +328,7 @@ html[data-theme=dark] .MathJax_SVG * {
 }
 .bd-sidebar-primary {
-  width: 22%; /* Adjust this value to your preference */
+  width: max-content; /* Adjust this value to your preference */
   line-height: 1.4;
 }

@@ -35,5 +35,5 @@ Please reference our [Review Process](review_process.mdx).
 ### I think my PR was closed in a way that didn't follow the review process. What should I do?
-Tag `@efriis` in the PR comments referencing the portion of the review
+Tag `@ccurme` in the PR comments referencing the portion of the review
 process that you believe was not followed. We'll take a look!

@@ -1 +1,2 @@
 httpx
+grpcio

@@ -30,6 +30,7 @@ class AscendEmbeddings(Embeddings, BaseModel):
     document_instruction: str = ""
     use_fp16: bool = True
     pooling_method: Optional[str] = "cls"
+    batch_size: int = 32
     model: Any
     tokenizer: Any
@@ -119,7 +120,18 @@ class AscendEmbeddings(Embeddings, BaseModel):
         )

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return self.encode([self.document_instruction + text for text in texts])
+        try:
+            import numpy as np
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import numpy, please install with `pip install -U numpy`."
+            ) from e
+        embedding_list = []
+        for i in range(0, len(texts), self.batch_size):
+            texts_ = texts[i : i + self.batch_size]
+            emb = self.encode([self.document_instruction + text for text in texts_])
+            embedding_list.append(emb)
+        return np.concatenate(embedding_list)

     def embed_query(self, text: str) -> List[float]:
         return self.encode([self.query_instruction + text])[0]
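For reference, the new embed_documents splits the input into batch_size chunks, encodes each batch, and concatenates the per-batch matrices with NumPy. A minimal standalone sketch of that batching pattern; fake_encode is a hypothetical stand-in for the model's encode call:

from typing import List

import numpy as np


def fake_encode(texts: List[str]) -> np.ndarray:
    # Hypothetical stand-in for AscendEmbeddings.encode: one 4-dim vector per text.
    return np.ones((len(texts), 4), dtype=np.float32)


def embed_documents(texts: List[str], batch_size: int = 32) -> np.ndarray:
    # Encode in slices of `batch_size`, then stack the per-batch results.
    embedding_list = []
    for i in range(0, len(texts), batch_size):
        embedding_list.append(fake_encode(texts[i : i + batch_size]))
    return np.concatenate(embedding_list)


print(embed_documents([f"doc {i}" for i in range(70)], batch_size=32).shape)  # (70, 4)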

@@ -1,8 +1,10 @@
 from __future__ import annotations

+import json
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncIterator,
     Dict,
     Generator,
     Iterator,
@@ -12,7 +14,12 @@ from typing import (
     Union,
 )

-from langchain_core.callbacks import CallbackManagerForLLMRun
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
@@ -126,6 +133,7 @@ class Xinference(LLM):
         self,
         server_url: Optional[str] = None,
         model_uid: Optional[str] = None,
+        api_key: Optional[str] = None,
         **model_kwargs: Any,
     ):
         try:
@@ -155,7 +163,13 @@
         if self.model_uid is None:
             raise ValueError("Please provide the model UID")

-        self.client = RESTfulClient(server_url)
+        self._headers: Dict[str, str] = {}
+        self._cluster_authed = False
+        self._check_cluster_authenticated()
+        if api_key is not None and self._cluster_authed:
+            self._headers["Authorization"] = f"Bearer {api_key}"
+
+        self.client = RESTfulClient(server_url, api_key)

     @property
     def _llm_type(self) -> str:
@@ -171,6 +185,20 @@
             **{"model_kwargs": self.model_kwargs},
         }

+    def _check_cluster_authenticated(self) -> None:
+        url = f"{self.server_url}/v1/cluster/auth"
+        response = requests.get(url)
+        if response.status_code == 404:
+            self._cluster_authed = False
+        else:
+            if response.status_code != 200:
+                raise RuntimeError(
+                    f"Failed to get cluster information, "
+                    f"detail: {response.json()['detail']}"
+                )
+            response_data = response.json()
+            self._cluster_authed = bool(response_data["auth"])
+
     def _call(
         self,
         prompt: str,
@@ -305,3 +333,61 @@ class Xinference(LLM):
             return GenerationChunk(text=token)
         else:
             raise TypeError("stream_response type error!")
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        generate_config = kwargs.get("generate_config", {})
+        generate_config = {**self.model_kwargs, **generate_config}
+        if stop:
+            generate_config["stop"] = stop
+        async for stream_resp in self._acreate_generate_stream(prompt, generate_config):
+            if stream_resp:
+                chunk = self._stream_response_to_generation_chunk(stream_resp)
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
+                yield chunk
+
+    async def _acreate_generate_stream(
+        self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
+    ) -> AsyncIterator[str]:
+        request_body: Dict[str, Any] = {"model": self.model_uid, "prompt": prompt}
+        if generate_config is not None:
+            for key, value in generate_config.items():
+                request_body[key] = value
+
+        stream = bool(generate_config and generate_config.get("stream"))
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=f"{self.server_url}/v1/completions",
+                json=request_body,
+            ) as response:
+                if response.status != 200:
+                    if response.status == 404:
+                        raise FileNotFoundError(
+                            "astream call failed with status code 404."
+                        )
+                    else:
+                        optional_detail = response.text
+                        raise ValueError(
+                            f"astream call failed with status code {response.status}."
+                            f" Details: {optional_detail}"
+                        )
+                async for line in response.content:
+                    if not stream:
+                        yield json.loads(line)
+                    else:
+                        json_str = line.decode("utf-8")
+                        if line.startswith(b"data:"):
+                            json_str = json_str[len(b"data:") :].strip()
+                            if not json_str:
+                                continue
+                            yield json.loads(json_str)
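A minimal usage sketch of the new api_key option and async streaming, assuming the xinference_client package is installed and a model has been launched on a running server; the server URL, model UID, and key below are placeholders:

import asyncio

from langchain_community.llms import Xinference


async def main() -> None:
    llm = Xinference(
        server_url="http://127.0.0.1:9997",  # placeholder server URL
        model_uid="my-model-uid",            # placeholder model UID
        api_key="sk-...",                    # Authorization header is only added when the cluster reports auth enabled
    )
    # The new _astream method backs the standard async streaming interface.
    async for chunk in llm.astream("Q: What is the largest ocean? A:"):
        print(chunk, end="", flush=True)


asyncio.run(main())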

@@ -2,6 +2,7 @@
 from __future__ import annotations

+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, TextIO, cast

 from langchain_core.callbacks import BaseCallbackHandler
@@ -30,7 +31,7 @@ class FileCallbackHandler(BaseCallbackHandler):
             mode: The mode to open the file in. Defaults to "a".
             color: The color to use for the text. Defaults to None.
         """
-        self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))  # noqa: SIM115
+        self.file = cast(TextIO, Path(filename).open(mode, encoding="utf-8"))  # noqa: SIM115
         self.color = color

     def __del__(self) -> None:

@@ -3,7 +3,7 @@ from __future__ import annotations
 import contextlib
 import mimetypes
 from io import BufferedReader, BytesIO
-from pathlib import PurePath
+from pathlib import Path, PurePath
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast

 from pydantic import ConfigDict, Field, field_validator, model_validator
@@ -151,8 +151,7 @@ class Blob(BaseMedia):
     def as_string(self) -> str:
         """Read data as a string."""
         if self.data is None and self.path:
-            with open(str(self.path), encoding=self.encoding) as f:
-                return f.read()
+            return Path(self.path).read_text(encoding=self.encoding)
         elif isinstance(self.data, bytes):
             return self.data.decode(self.encoding)
         elif isinstance(self.data, str):
@@ -168,8 +167,7 @@
         elif isinstance(self.data, str):
             return self.data.encode(self.encoding)
         elif self.data is None and self.path:
-            with open(str(self.path), "rb") as f:
-                return f.read()
+            return Path(self.path).read_bytes()
         else:
             msg = f"Unable to get bytes for blob {self}"
             raise ValueError(msg)
@@ -180,7 +178,7 @@
         if isinstance(self.data, bytes):
             yield BytesIO(self.data)
         elif self.data is None and self.path:
-            with open(str(self.path), "rb") as f:
+            with Path(self.path).open("rb") as f:
                 yield f
         else:
             msg = f"Unable to convert blob {self}"

@@ -1402,7 +1402,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 llm.save(file_path="path/llm.yaml")
         """
         # Convert file to Path object.
-        save_path = Path(file_path) if isinstance(file_path, str) else file_path
+        save_path = Path(file_path)

         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)
@@ -1411,10 +1411,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         prompt_dict = self.dict()

         if save_path.suffix == ".json":
-            with open(file_path, "w") as f:
+            with save_path.open("w") as f:
                 json.dump(prompt_dict, f, indent=4)
         elif save_path.suffix.endswith((".yaml", ".yml")):
-            with open(file_path, "w") as f:
+            with save_path.open("w") as f:
                 yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             msg = f"{save_path} must be json or yaml"

@@ -368,16 +368,16 @@ class BasePromptTemplate(
             raise NotImplementedError(msg)

         # Convert file to Path object.
-        save_path = Path(file_path) if isinstance(file_path, str) else file_path
+        save_path = Path(file_path)

         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)

         if save_path.suffix == ".json":
-            with open(file_path, "w") as f:
+            with save_path.open("w") as f:
                 json.dump(prompt_dict, f, indent=4)
         elif save_path.suffix.endswith((".yaml", ".yml")):
-            with open(file_path, "w") as f:
+            with save_path.open("w") as f:
                 yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             msg = f"{save_path} must be json or yaml"

@@ -3,6 +3,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Annotated,
@@ -48,7 +49,6 @@ from langchain_core.utils.interactive_env import is_interactive_env

 if TYPE_CHECKING:
     from collections.abc import Sequence
-    from pathlib import Path


 class BaseMessagePromptTemplate(Serializable, ABC):
@@ -599,8 +599,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
         Returns:
             A new instance of this class.
         """
-        with open(str(template_file)) as f:
-            template = f.read()
+        template = Path(template_file).read_text()
         return cls.from_template(template, input_variables=input_variables, **kwargs)

     def format_messages(self, **kwargs: Any) -> list[BaseMessage]:

@@ -53,8 +53,7 @@ def _load_template(var_name: str, config: dict) -> dict:
         template_path = Path(config.pop(f"{var_name}_path"))
         # Load the template.
         if template_path.suffix == ".txt":
-            with open(template_path) as f:
-                template = f.read()
+            template = template_path.read_text()
         else:
             raise ValueError
         # Set the template variable to the extracted variable.
@@ -67,10 +66,11 @@ def _load_examples(config: dict) -> dict:
     if isinstance(config["examples"], list):
         pass
     elif isinstance(config["examples"], str):
-        with open(config["examples"]) as f:
-            if config["examples"].endswith(".json"):
+        path = Path(config["examples"])
+        with path.open() as f:
+            if path.suffix == ".json":
                 examples = json.load(f)
-            elif config["examples"].endswith((".yaml", ".yml")):
+            elif path.suffix in {".yaml", ".yml"}:
                 examples = yaml.safe_load(f)
             else:
                 msg = "Invalid file format. Only json or yaml formats are supported."
@@ -168,13 +168,13 @@ def _load_prompt_from_file(
 ) -> BasePromptTemplate:
     """Load prompt from file."""
     # Convert file to a Path object.
-    file_path = Path(file) if isinstance(file, str) else file
+    file_path = Path(file)

     # Load from either json or yaml.
     if file_path.suffix == ".json":
-        with open(file_path, encoding=encoding) as f:
+        with file_path.open(encoding=encoding) as f:
             config = json.load(f)
     elif file_path.suffix.endswith((".yaml", ".yml")):
-        with open(file_path, encoding=encoding) as f:
+        with file_path.open(encoding=encoding) as f:
             config = yaml.safe_load(f)
     else:
         msg = f"Got unsupported file type {file_path.suffix}"

@@ -3,6 +3,7 @@
 from __future__ import annotations

 import warnings
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union

 from pydantic import BaseModel, model_validator
@@ -17,8 +18,6 @@ from langchain_core.prompts.string import (
 )

 if TYPE_CHECKING:
-    from pathlib import Path
-
     from langchain_core.runnables.config import RunnableConfig
@@ -238,8 +237,7 @@ class PromptTemplate(StringPromptTemplate):
         Returns:
             The prompt loaded from the file.
         """
-        with open(str(template_file), encoding=encoding) as f:
-            template = f.read()
+        template = Path(template_file).read_text(encoding=encoding)
         if input_variables:
             warnings.warn(
                 "`input_variables' is deprecated and ignored.",

@@ -1644,8 +1644,13 @@ class Runnable(Generic[Input, Output], ABC):
             .. code-block:: python

-                from langchain_core.runnables import RunnableLambda
+                from langchain_core.runnables import RunnableLambda, Runnable
+                from datetime import datetime, timezone
                 import time
+                import asyncio
+
+                def format_t(timestamp: float) -> str:
+                    return datetime.fromtimestamp(timestamp, tz=timezone.utc).isoformat()

                 async def test_runnable(time_to_sleep : int):
                     print(f"Runnable[{time_to_sleep}s]: starts at {format_t(time.time())}")
@@ -1653,12 +1658,12 @@
                     print(f"Runnable[{time_to_sleep}s]: ends at {format_t(time.time())}")

                 async def fn_start(run_obj : Runnable):
-                    print(f"on start callback starts at {format_t(time.time())}
+                    print(f"on start callback starts at {format_t(time.time())}")
                     await asyncio.sleep(3)
                     print(f"on start callback ends at {format_t(time.time())}")

                 async def fn_end(run_obj : Runnable):
-                    print(f"on end callback starts at {format_t(time.time())}
+                    print(f"on end callback starts at {format_t(time.time())}")
                     await asyncio.sleep(2)
                     print(f"on end callback ends at {format_t(time.time())}")
@@ -1671,18 +1676,18 @@
                 asyncio.run(concurrent_runs())

             Result:
-                on start callback starts at 2024-05-16T14:20:29.637053+00:00
-                on start callback starts at 2024-05-16T14:20:29.637150+00:00
-                on start callback ends at 2024-05-16T14:20:32.638305+00:00
-                on start callback ends at 2024-05-16T14:20:32.638383+00:00
-                Runnable[3s]: starts at 2024-05-16T14:20:32.638849+00:00
-                Runnable[5s]: starts at 2024-05-16T14:20:32.638999+00:00
-                Runnable[3s]: ends at 2024-05-16T14:20:35.640016+00:00
-                on end callback starts at 2024-05-16T14:20:35.640534+00:00
-                Runnable[5s]: ends at 2024-05-16T14:20:37.640169+00:00
-                on end callback starts at 2024-05-16T14:20:37.640574+00:00
-                on end callback ends at 2024-05-16T14:20:37.640654+00:00
-                on end callback ends at 2024-05-16T14:20:39.641751+00:00
+                on start callback starts at 2025-03-01T07:05:22.875378+00:00
+                on start callback starts at 2025-03-01T07:05:22.875495+00:00
+                on start callback ends at 2025-03-01T07:05:25.878862+00:00
+                on start callback ends at 2025-03-01T07:05:25.878947+00:00
+                Runnable[2s]: starts at 2025-03-01T07:05:25.879392+00:00
+                Runnable[3s]: starts at 2025-03-01T07:05:25.879804+00:00
+                Runnable[2s]: ends at 2025-03-01T07:05:27.881998+00:00
+                on end callback starts at 2025-03-01T07:05:27.882360+00:00
+                Runnable[3s]: ends at 2025-03-01T07:05:28.881737+00:00
+                on end callback starts at 2025-03-01T07:05:28.882428+00:00
+                on end callback ends at 2025-03-01T07:05:29.883893+00:00
+                on end callback ends at 2025-03-01T07:05:30.884831+00:00
         """
         from langchain_core.tracers.root_listeners import AsyncRootListenersTracer

@@ -2,6 +2,7 @@ import asyncio
 import base64
 import re
 from dataclasses import asdict
+from pathlib import Path
 from typing import Literal, Optional

 from langchain_core.runnables.graph import (
@@ -290,13 +291,9 @@ async def _render_mermaid_using_pyppeteer(
     img_bytes = await page.screenshot({"fullPage": False})
     await browser.close()

-    def write_to_file(path: str, bytes: bytes) -> None:
-        with open(path, "wb") as file:
-            file.write(bytes)
-
     if output_file_path is not None:
         await asyncio.get_event_loop().run_in_executor(
-            None, write_to_file, output_file_path, img_bytes
+            None, Path(output_file_path).write_bytes, img_bytes
         )

     return img_bytes
@@ -337,8 +334,7 @@ def _render_mermaid_using_api(
     if response.status_code == 200:
         img_bytes = response.content
         if output_file_path is not None:
-            with open(output_file_path, "wb") as file:
-                file.write(response.content)
+            Path(output_file_path).write_bytes(response.content)

         return img_bytes
     else:

@@ -77,7 +77,7 @@ target-version = "py39"

 [tool.ruff.lint]
-select = [ "ANN", "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TC", "TID", "TRY", "UP", "W", "YTT",]
+select = [ "ANN", "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "PTH", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TC", "TID", "TRY", "UP", "W", "YTT",]
 ignore = [ "ANN401", "COM812", "UP007", "S110", "S112", "TC001", "TC002", "TC003"]
 flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
 flake8-annotations.allow-star-arg-any = true

@@ -354,7 +354,7 @@ def test_prompt_from_file_with_partial_variables() -> None:
     template = "This is a {foo} test {bar}."
     partial_variables = {"bar": "baz"}
     # when
-    with mock.patch("builtins.open", mock.mock_open(read_data=template)):
+    with mock.patch("pathlib.Path.open", mock.mock_open(read_data=template)):
         prompt = PromptTemplate.from_file(
             "mock_file_name", partial_variables=partial_variables
         )

@@ -1,20 +1,17 @@
 import concurrent.futures
-import glob
 import importlib
 import subprocess
 from pathlib import Path


 def test_importable_all() -> None:
-    for path in glob.glob("../core/langchain_core/*"):
-        relative_path = Path(path).parts[-1]
-        if relative_path.endswith(".typed"):
-            continue
-        module_name = relative_path.split(".")[0]
-        module = importlib.import_module("langchain_core." + module_name)
-        all_ = getattr(module, "__all__", [])
-        for cls_ in all_:
-            getattr(module, cls_)
+    for path in Path("../core/langchain_core/").glob("*"):
+        module_name = path.stem
+        if not module_name.startswith(".") and path.suffix != ".typed":
+            module = importlib.import_module("langchain_core." + module_name)
+            all_ = getattr(module, "__all__", [])
+            for cls_ in all_:
+                getattr(module, cls_)


 def try_to_import(module_name: str) -> tuple[int, str]:
@@ -37,12 +34,10 @@ def test_importable_all_via_subprocess() -> None:
     for one sequence of imports but not another.
     """
     module_names = []
-    for path in glob.glob("../core/langchain_core/*"):
-        relative_path = Path(path).parts[-1]
-        if relative_path.endswith(".typed"):
-            continue
-        module_name = relative_path.split(".")[0]
-        module_names.append(module_name)
+    for path in Path("../core/langchain_core/").glob("*"):
+        module_name = path.stem
+        if not module_name.startswith(".") and path.suffix != ".typed":
+            module_names.append(module_name)

     with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
         futures = [

@@ -26,6 +26,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
@@ -83,6 +84,15 @@ _message_type_lookups = {
 }


+class AnthropicTool(TypedDict):
+    """Anthropic tool definition."""
+
+    name: str
+    description: str
+    input_schema: Dict[str, Any]
+    cache_control: NotRequired[Dict[str, str]]
+
+
 def _format_image(image_url: str) -> Dict:
     """
     Formats an image of format data:image/jpeg;base64,{b64_string}
@@ -900,6 +910,8 @@ class ChatAnthropic(BaseChatModel):
         llm_output = {
             k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
         }
+        if "model" in llm_output and "model_name" not in llm_output:
+            llm_output["model_name"] = llm_output["model"]
         if (
             len(content) == 1
             and content[0]["type"] == "text"
@@ -952,6 +964,31 @@
         data = await self._async_client.messages.create(**payload)
         return self._format_output(data, **kwargs)

+    def _get_llm_for_structured_output_when_thinking_is_enabled(
+        self,
+        schema: Union[Dict, type],
+        formatted_tool: AnthropicTool,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        thinking_admonition = (
+            "Anthropic structured output relies on forced tool calling, "
+            "which is not supported when `thinking` is enabled. This method will raise "
+            "langchain_core.exceptions.OutputParserException if tool calls are not "
+            "generated. Consider disabling `thinking` or adjust your prompt to ensure "
+            "the tool is called."
+        )
+        warnings.warn(thinking_admonition)
+        llm = self.bind_tools(
+            [schema],
+            structured_output_format={"kwargs": {}, "schema": formatted_tool},
+        )
+
+        def _raise_if_no_tool_calls(message: AIMessage) -> AIMessage:
+            if not message.tool_calls:
+                raise OutputParserException(thinking_admonition)
+            return message
+
+        return llm | _raise_if_no_tool_calls
+
     def bind_tools(
         self,
         tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
@@ -1249,11 +1286,17 @@
         """  # noqa: E501
         formatted_tool = convert_to_anthropic_tool(schema)
         tool_name = formatted_tool["name"]
-        llm = self.bind_tools(
-            [schema],
-            tool_choice=tool_name,
-            structured_output_format={"kwargs": {}, "schema": formatted_tool},
-        )
+        if self.thinking is not None and self.thinking.get("type") == "enabled":
+            llm = self._get_llm_for_structured_output_when_thinking_is_enabled(
+                schema, formatted_tool
+            )
+        else:
+            llm = self.bind_tools(
+                [schema],
+                tool_choice=tool_name,
+                structured_output_format={"kwargs": {}, "schema": formatted_tool},
+            )
         if isinstance(schema, type) and is_basemodel_subclass(schema):
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema], first_tool_only=True
@@ -1356,15 +1399,6 @@
         return response.input_tokens


-class AnthropicTool(TypedDict):
-    """Anthropic tool definition."""
-
-    name: str
-    description: str
-    input_schema: Dict[str, Any]
-    cache_control: NotRequired[Dict[str, str]]
-
-
 def convert_to_anthropic_tool(
     tool: Union[Dict[str, Any], Type, Callable, BaseTool],
 ) -> AnthropicTool:
@@ -1445,9 +1479,14 @@ def _make_message_chunk_from_anthropic_event(
     # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
     if event.type == "message_start" and stream_usage:
         usage_metadata = _create_usage_metadata(event.message.usage)
+        if hasattr(event.message, "model"):
+            response_metadata = {"model_name": event.message.model}
+        else:
+            response_metadata = {}
         message_chunk = AIMessageChunk(
             content="" if coerce_content_to_string else [],
             usage_metadata=usage_metadata,
+            response_metadata=response_metadata,
         )
     elif (
         event.type == "content_block_start"

@@ -6,7 +6,9 @@ from typing import List, Optional

 import pytest
 import requests
+from anthropic import BadRequestError
 from langchain_core.callbacks import CallbackManager
+from langchain_core.exceptions import OutputParserException
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -35,6 +37,7 @@ def test_stream() -> None:
     full: Optional[BaseMessageChunk] = None
     chunks_with_input_token_counts = 0
     chunks_with_output_token_counts = 0
+    chunks_with_model_name = 0
     for token in llm.stream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
         full = token if full is None else full + token
@@ -44,12 +47,14 @@ def test_stream() -> None:
                 chunks_with_input_token_counts += 1
             elif token.usage_metadata.get("output_tokens"):
                 chunks_with_output_token_counts += 1
+        chunks_with_model_name += int("model_name" in token.response_metadata)
     if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
         raise AssertionError(
             "Expected exactly one chunk with input or output token counts. "
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
+    assert chunks_with_model_name == 1
     # check token usage is populated
     assert isinstance(full, AIMessageChunk)
     assert full.usage_metadata is not None
@@ -62,6 +67,7 @@ def test_stream() -> None:
     )
     assert "stop_reason" in full.response_metadata
     assert "stop_sequence" in full.response_metadata
+    assert "model_name" in full.response_metadata


 async def test_astream() -> None:
@@ -219,6 +225,7 @@ async def test_ainvoke() -> None:
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
+    assert "model_name" in result.response_metadata


 def test_invoke() -> None:
@@ -725,3 +732,39 @@ def test_redacted_thinking() -> None:
             assert set(block.keys()) == {"type", "data", "index"}
             assert block["data"] and isinstance(block["data"], str)
     assert stream_has_reasoning
+
+
+def test_structured_output_thinking_enabled() -> None:
+    llm = ChatAnthropic(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=5_000,
+        thinking={"type": "enabled", "budget_tokens": 2_000},
+    )
+    with pytest.warns(match="structured output"):
+        structured_llm = llm.with_structured_output(GenerateUsername)
+    query = "Generate a username for Sally with green hair"
+    response = structured_llm.invoke(query)
+    assert isinstance(response, GenerateUsername)
+
+    with pytest.raises(OutputParserException):
+        structured_llm.invoke("Hello")
+
+    # Test streaming
+    for chunk in structured_llm.stream(query):
+        assert isinstance(chunk, GenerateUsername)
+
+
+def test_structured_output_thinking_force_tool_use() -> None:
+    # Structured output currently relies on forced tool use, which is not supported
+    # when `thinking` is enabled. When this test fails, it means that the feature
+    # is supported and the workarounds in `with_structured_output` should be removed.
+    llm = ChatAnthropic(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=5_000,
+        thinking={"type": "enabled", "budget_tokens": 2_000},
+    ).bind_tools(
+        [GenerateUsername],
+        tool_choice="GenerateUsername",
+    )
+    with pytest.raises(BadRequestError):
+        llm.invoke("Generate a username for Sally with green hair")

@@ -579,7 +579,11 @@ class ChatMistralAI(BaseChatModel):
             )
             generations.append(gen)

-        llm_output = {"token_usage": token_usage, "model": self.model}
+        llm_output = {
+            "token_usage": token_usage,
+            "model_name": self.model,
+            "model": self.model,  # Backwards compatability
+        }
         return ChatResult(generations=generations, llm_output=llm_output)

     def _create_message_dicts(

@@ -87,6 +87,7 @@ async def test_ainvoke() -> None:
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
+    assert "model_name" in result.response_metadata


 def test_invoke() -> None: