Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-24 15:43:54 +00:00)

Merge branch 'master' into pprados/06-pdfplumber

Commit: bd3a24f2d1
.github/CODEOWNERS (vendored, 4 lines changed)

@@ -1,2 +1,2 @@
-/.github/ @efriis @baskaryan @ccurme
-/libs/packages.yml @efriis
+/.github/ @baskaryan @ccurme
+/libs/packages.yml @ccurme
.github/PULL_REQUEST_TEMPLATE.md (vendored, 2 lines changed)

@@ -26,4 +26,4 @@ Additional guidelines:
 - Changes should be backwards compatible.
 - If you are adding something to community, do not re-import it in langchain.
 
-If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
+If no one reviews your PR within a few days, please @-mention one of baskaryan, eyurtsev, ccurme, vbarda, hwchase17.
@@ -328,7 +328,7 @@ html[data-theme=dark] .MathJax_SVG * {
 }
 
 .bd-sidebar-primary {
-  width: 22%; /* Adjust this value to your preference */
+  width: max-content; /* Adjust this value to your preference */
   line-height: 1.4;
 }
 
@@ -35,5 +35,5 @@ Please reference our [Review Process](review_process.mdx).
 
 ### I think my PR was closed in a way that didn't follow the review process. What should I do?
 
-Tag `@efriis` in the PR comments referencing the portion of the review
+Tag `@ccurme` in the PR comments referencing the portion of the review
 process that you believe was not followed. We'll take a look!
@@ -1 +1,2 @@
 httpx
+grpcio
@@ -30,6 +30,7 @@ class AscendEmbeddings(Embeddings, BaseModel):
     document_instruction: str = ""
     use_fp16: bool = True
     pooling_method: Optional[str] = "cls"
+    batch_size: int = 32
     model: Any
     tokenizer: Any
 
@@ -119,7 +120,18 @@ class AscendEmbeddings(Embeddings, BaseModel):
         )
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return self.encode([self.document_instruction + text for text in texts])
+        try:
+            import numpy as np
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import numpy, please install with `pip install -U numpy`."
+            ) from e
+        embedding_list = []
+        for i in range(0, len(texts), self.batch_size):
+            texts_ = texts[i : i + self.batch_size]
+            emb = self.encode([self.document_instruction + text for text in texts_])
+            embedding_list.append(emb)
+        return np.concatenate(embedding_list)
 
     def embed_query(self, text: str) -> List[float]:
         return self.encode([self.query_instruction + text])[0]
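The new `embed_documents` above splits the input into `batch_size` slices before encoding and concatenates the per-batch arrays. A minimal standalone sketch of that pattern (the `encode` stub is a placeholder for the real model call; numpy is assumed to be installed):

from typing import List

import numpy as np


def encode(batch: List[str]) -> np.ndarray:
    # Placeholder encoder: returns one 8-dimensional vector per input text.
    return np.zeros((len(batch), 8))


def embed_in_batches(texts: List[str], batch_size: int = 32) -> np.ndarray:
    # Encode fixed-size slices and stack the results into a single array.
    chunks = [
        encode(texts[i : i + batch_size]) for i in range(0, len(texts), batch_size)
    ]
    return np.concatenate(chunks)


print(embed_in_batches(["a"] * 70).shape)  # (70, 8)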
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import json
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncIterator,
     Dict,
     Generator,
     Iterator,
@@ -12,7 +14,12 @@ from typing import (
     Union,
 )
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+import aiohttp
+import requests
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 
@@ -126,6 +133,7 @@ class Xinference(LLM):
         self,
         server_url: Optional[str] = None,
         model_uid: Optional[str] = None,
+        api_key: Optional[str] = None,
         **model_kwargs: Any,
     ):
         try:
@@ -155,7 +163,13 @@ class Xinference(LLM):
         if self.model_uid is None:
             raise ValueError("Please provide the model UID")
 
-        self.client = RESTfulClient(server_url)
+        self._headers: Dict[str, str] = {}
+        self._cluster_authed = False
+        self._check_cluster_authenticated()
+        if api_key is not None and self._cluster_authed:
+            self._headers["Authorization"] = f"Bearer {api_key}"
+
+        self.client = RESTfulClient(server_url, api_key)
 
     @property
     def _llm_type(self) -> str:
@@ -171,6 +185,20 @@ class Xinference(LLM):
             **{"model_kwargs": self.model_kwargs},
         }
 
+    def _check_cluster_authenticated(self) -> None:
+        url = f"{self.server_url}/v1/cluster/auth"
+        response = requests.get(url)
+        if response.status_code == 404:
+            self._cluster_authed = False
+        else:
+            if response.status_code != 200:
+                raise RuntimeError(
+                    f"Failed to get cluster information, "
+                    f"detail: {response.json()['detail']}"
+                )
+            response_data = response.json()
+            self._cluster_authed = bool(response_data["auth"])
+
     def _call(
         self,
         prompt: str,
@@ -305,3 +333,61 @@ class Xinference(LLM):
             return GenerationChunk(text=token)
         else:
             raise TypeError("stream_response type error!")
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        generate_config = kwargs.get("generate_config", {})
+        generate_config = {**self.model_kwargs, **generate_config}
+        if stop:
+            generate_config["stop"] = stop
+        async for stream_resp in self._acreate_generate_stream(prompt, generate_config):
+            if stream_resp:
+                chunk = self._stream_response_to_generation_chunk(stream_resp)
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
+                yield chunk
+
+    async def _acreate_generate_stream(
+        self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
+    ) -> AsyncIterator[str]:
+        request_body: Dict[str, Any] = {"model": self.model_uid, "prompt": prompt}
+        if generate_config is not None:
+            for key, value in generate_config.items():
+                request_body[key] = value
+
+        stream = bool(generate_config and generate_config.get("stream"))
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url=f"{self.server_url}/v1/completions",
+                json=request_body,
+            ) as response:
+                if response.status != 200:
+                    if response.status == 404:
+                        raise FileNotFoundError(
+                            "astream call failed with status code 404."
+                        )
+                    else:
+                        optional_detail = response.text
+                        raise ValueError(
+                            f"astream call failed with status code {response.status}."
+                            f" Details: {optional_detail}"
+                        )
+
+                async for line in response.content:
+                    if not stream:
+                        yield json.loads(line)
+                    else:
+                        json_str = line.decode("utf-8")
+                        if line.startswith(b"data:"):
+                            json_str = json_str[len(b"data:") :].strip()
+                            if not json_str:
+                                continue
+                            yield json.loads(json_str)
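With `_astream` and `_acreate_generate_stream` added, the async `astream` interface inherited from the base `LLM` class can yield tokens incrementally. A hedged usage sketch (the server URL and model UID are placeholders and assume a running Xinference server with that model already launched):

import asyncio

from langchain_community.llms import Xinference


async def main() -> None:
    llm = Xinference(
        server_url="http://127.0.0.1:9997",  # placeholder endpoint
        model_uid="my-model-uid",  # placeholder UID returned by launch_model
    )
    # astream delegates to the new _astream; generate_config requests SSE streaming.
    async for token in llm.astream(
        "Q: Tell me a joke. A:", generate_config={"stream": True}
    ):
        print(token, end="", flush=True)


asyncio.run(main())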
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, TextIO, cast
 
 from langchain_core.callbacks import BaseCallbackHandler

@@ -30,7 +31,7 @@ class FileCallbackHandler(BaseCallbackHandler):
             mode: The mode to open the file in. Defaults to "a".
             color: The color to use for the text. Defaults to None.
         """
-        self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))  # noqa: SIM115
+        self.file = cast(TextIO, Path(filename).open(mode, encoding="utf-8"))  # noqa: SIM115
         self.color = color
 
     def __del__(self) -> None:
@@ -3,7 +3,7 @@ from __future__ import annotations
 import contextlib
 import mimetypes
 from io import BufferedReader, BytesIO
-from pathlib import PurePath
+from pathlib import Path, PurePath
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
 
 from pydantic import ConfigDict, Field, field_validator, model_validator

@@ -151,8 +151,7 @@ class Blob(BaseMedia):
     def as_string(self) -> str:
         """Read data as a string."""
         if self.data is None and self.path:
-            with open(str(self.path), encoding=self.encoding) as f:
-                return f.read()
+            return Path(self.path).read_text(encoding=self.encoding)
         elif isinstance(self.data, bytes):
             return self.data.decode(self.encoding)
         elif isinstance(self.data, str):

@@ -168,8 +167,7 @@ class Blob(BaseMedia):
         elif isinstance(self.data, str):
             return self.data.encode(self.encoding)
         elif self.data is None and self.path:
-            with open(str(self.path), "rb") as f:
-                return f.read()
+            return Path(self.path).read_bytes()
         else:
             msg = f"Unable to get bytes for blob {self}"
             raise ValueError(msg)

@@ -180,7 +178,7 @@ class Blob(BaseMedia):
         if isinstance(self.data, bytes):
             yield BytesIO(self.data)
         elif self.data is None and self.path:
-            with open(str(self.path), "rb") as f:
+            with Path(self.path).open("rb") as f:
                 yield f
         else:
             msg = f"Unable to convert blob {self}"
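These `Blob` changes follow the pattern applied throughout this merge: `open(str(path))` calls are replaced with their `pathlib` equivalents. A small self-contained sketch of the substitutions (the file name is a placeholder):

from pathlib import Path

path = Path("example.txt")  # placeholder file
path.write_text("hello", encoding="utf-8")

text = path.read_text(encoding="utf-8")  # instead of open(str(path)).read()
data = path.read_bytes()  # instead of open(str(path), "rb").read()
with path.open("rb") as f:  # instead of open(str(path), "rb")
    first_byte = f.read(1)

print(text, data, first_byte)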
@@ -1402,7 +1402,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 llm.save(file_path="path/llm.yaml")
         """
         # Convert file to Path object.
-        save_path = Path(file_path) if isinstance(file_path, str) else file_path
+        save_path = Path(file_path)
 
         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)

@@ -1411,10 +1411,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         prompt_dict = self.dict()
 
         if save_path.suffix == ".json":
-            with open(file_path, "w") as f:
+            with save_path.open("w") as f:
                 json.dump(prompt_dict, f, indent=4)
         elif save_path.suffix.endswith((".yaml", ".yml")):
-            with open(file_path, "w") as f:
+            with save_path.open("w") as f:
                 yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             msg = f"{save_path} must be json or yaml"
@@ -368,16 +368,16 @@ class BasePromptTemplate(
             raise NotImplementedError(msg)
 
         # Convert file to Path object.
-        save_path = Path(file_path) if isinstance(file_path, str) else file_path
+        save_path = Path(file_path)
 
         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)
 
         if save_path.suffix == ".json":
-            with open(file_path, "w") as f:
+            with save_path.open("w") as f:
                 json.dump(prompt_dict, f, indent=4)
         elif save_path.suffix.endswith((".yaml", ".yml")):
-            with open(file_path, "w") as f:
+            with save_path.open("w") as f:
                 yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             msg = f"{save_path} must be json or yaml"
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Annotated,

@@ -48,7 +49,6 @@ from langchain_core.utils.interactive_env import is_interactive_env
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
-    from pathlib import Path
 
 
 class BaseMessagePromptTemplate(Serializable, ABC):

@@ -599,8 +599,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
         Returns:
             A new instance of this class.
         """
-        with open(str(template_file)) as f:
-            template = f.read()
+        template = Path(template_file).read_text()
         return cls.from_template(template, input_variables=input_variables, **kwargs)
 
     def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
@@ -53,8 +53,7 @@ def _load_template(var_name: str, config: dict) -> dict:
     template_path = Path(config.pop(f"{var_name}_path"))
     # Load the template.
     if template_path.suffix == ".txt":
-        with open(template_path) as f:
-            template = f.read()
+        template = template_path.read_text()
     else:
         raise ValueError
     # Set the template variable to the extracted variable.

@@ -67,10 +66,11 @@ def _load_examples(config: dict) -> dict:
     if isinstance(config["examples"], list):
         pass
     elif isinstance(config["examples"], str):
-        with open(config["examples"]) as f:
-            if config["examples"].endswith(".json"):
+        path = Path(config["examples"])
+        with path.open() as f:
+            if path.suffix == ".json":
                 examples = json.load(f)
-            elif config["examples"].endswith((".yaml", ".yml")):
+            elif path.suffix in {".yaml", ".yml"}:
                 examples = yaml.safe_load(f)
             else:
                 msg = "Invalid file format. Only json or yaml formats are supported."

@@ -168,13 +168,13 @@ def _load_prompt_from_file(
 ) -> BasePromptTemplate:
     """Load prompt from file."""
     # Convert file to a Path object.
-    file_path = Path(file) if isinstance(file, str) else file
+    file_path = Path(file)
     # Load from either json or yaml.
     if file_path.suffix == ".json":
-        with open(file_path, encoding=encoding) as f:
+        with file_path.open(encoding=encoding) as f:
            config = json.load(f)
     elif file_path.suffix.endswith((".yaml", ".yml")):
-        with open(file_path, encoding=encoding) as f:
+        with file_path.open(encoding=encoding) as f:
            config = yaml.safe_load(f)
    else:
        msg = f"Got unsupported file type {file_path.suffix}"
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import warnings
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from pydantic import BaseModel, model_validator

@@ -17,8 +18,6 @@ from langchain_core.prompts.string import (
 )
 
 if TYPE_CHECKING:
-    from pathlib import Path
-
     from langchain_core.runnables.config import RunnableConfig
 
 

@@ -238,8 +237,7 @@ class PromptTemplate(StringPromptTemplate):
         Returns:
             The prompt loaded from the file.
         """
-        with open(str(template_file), encoding=encoding) as f:
-            template = f.read()
+        template = Path(template_file).read_text(encoding=encoding)
         if input_variables:
             warnings.warn(
                 "`input_variables' is deprecated and ignored.",
@@ -1644,8 +1644,13 @@ class Runnable(Generic[Input, Output], ABC):
 
         .. code-block:: python
 
-            from langchain_core.runnables import RunnableLambda
+            from langchain_core.runnables import RunnableLambda, Runnable
+            from datetime import datetime, timezone
             import time
             import asyncio
 
+            def format_t(timestamp: float) -> str:
+                return datetime.fromtimestamp(timestamp, tz=timezone.utc).isoformat()
+
             async def test_runnable(time_to_sleep : int):
                 print(f"Runnable[{time_to_sleep}s]: starts at {format_t(time.time())}")

@@ -1653,12 +1658,12 @@ class Runnable(Generic[Input, Output], ABC):
                 print(f"Runnable[{time_to_sleep}s]: ends at {format_t(time.time())}")
 
             async def fn_start(run_obj : Runnable):
-                print(f"on start callback starts at {format_t(time.time())}
+                print(f"on start callback starts at {format_t(time.time())}")
                 await asyncio.sleep(3)
                 print(f"on start callback ends at {format_t(time.time())}")
 
             async def fn_end(run_obj : Runnable):
-                print(f"on end callback starts at {format_t(time.time())}
+                print(f"on end callback starts at {format_t(time.time())}")
                 await asyncio.sleep(2)
                 print(f"on end callback ends at {format_t(time.time())}")
 

@@ -1671,18 +1676,18 @@ class Runnable(Generic[Input, Output], ABC):
 
             asyncio.run(concurrent_runs())
             Result:
-            on start callback starts at 2024-05-16T14:20:29.637053+00:00
-            on start callback starts at 2024-05-16T14:20:29.637150+00:00
-            on start callback ends at 2024-05-16T14:20:32.638305+00:00
-            on start callback ends at 2024-05-16T14:20:32.638383+00:00
-            Runnable[3s]: starts at 2024-05-16T14:20:32.638849+00:00
-            Runnable[5s]: starts at 2024-05-16T14:20:32.638999+00:00
-            Runnable[3s]: ends at 2024-05-16T14:20:35.640016+00:00
-            on end callback starts at 2024-05-16T14:20:35.640534+00:00
-            Runnable[5s]: ends at 2024-05-16T14:20:37.640169+00:00
-            on end callback starts at 2024-05-16T14:20:37.640574+00:00
-            on end callback ends at 2024-05-16T14:20:37.640654+00:00
-            on end callback ends at 2024-05-16T14:20:39.641751+00:00
+            on start callback starts at 2025-03-01T07:05:22.875378+00:00
+            on start callback starts at 2025-03-01T07:05:22.875495+00:00
+            on start callback ends at 2025-03-01T07:05:25.878862+00:00
+            on start callback ends at 2025-03-01T07:05:25.878947+00:00
+            Runnable[2s]: starts at 2025-03-01T07:05:25.879392+00:00
+            Runnable[3s]: starts at 2025-03-01T07:05:25.879804+00:00
+            Runnable[2s]: ends at 2025-03-01T07:05:27.881998+00:00
+            on end callback starts at 2025-03-01T07:05:27.882360+00:00
+            Runnable[3s]: ends at 2025-03-01T07:05:28.881737+00:00
+            on end callback starts at 2025-03-01T07:05:28.882428+00:00
+            on end callback ends at 2025-03-01T07:05:29.883893+00:00
+            on end callback ends at 2025-03-01T07:05:30.884831+00:00
 
         """
         from langchain_core.tracers.root_listeners import AsyncRootListenersTracer
@@ -2,6 +2,7 @@ import asyncio
 import base64
 import re
 from dataclasses import asdict
+from pathlib import Path
 from typing import Literal, Optional
 
 from langchain_core.runnables.graph import (

@@ -290,13 +291,9 @@ async def _render_mermaid_using_pyppeteer(
     img_bytes = await page.screenshot({"fullPage": False})
     await browser.close()
 
-    def write_to_file(path: str, bytes: bytes) -> None:
-        with open(path, "wb") as file:
-            file.write(bytes)
-
     if output_file_path is not None:
         await asyncio.get_event_loop().run_in_executor(
-            None, write_to_file, output_file_path, img_bytes
+            None, Path(output_file_path).write_bytes, img_bytes
         )
 
     return img_bytes

@@ -337,8 +334,7 @@ def _render_mermaid_using_api(
     if response.status_code == 200:
         img_bytes = response.content
         if output_file_path is not None:
-            with open(output_file_path, "wb") as file:
-                file.write(response.content)
+            Path(output_file_path).write_bytes(response.content)
 
         return img_bytes
     else:
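In the pyppeteer branch above, the local `write_to_file` helper is dropped and `Path.write_bytes` is handed to the executor directly. A minimal sketch of that pattern (the output file name is illustrative):

import asyncio
from pathlib import Path


async def save_png(img_bytes: bytes) -> None:
    loop = asyncio.get_running_loop()
    # Run the blocking file write in the default thread pool executor.
    await loop.run_in_executor(None, Path("graph.png").write_bytes, img_bytes)


asyncio.run(save_png(b"\x89PNG\r\n\x1a\n"))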
@@ -77,7 +77,7 @@ target-version = "py39"
 
 
 [tool.ruff.lint]
-select = [ "ANN", "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TC", "TID", "TRY", "UP", "W", "YTT",]
+select = [ "ANN", "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "PTH", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TC", "TID", "TRY", "UP", "W", "YTT",]
 ignore = [ "ANN401", "COM812", "UP007", "S110", "S112", "TC001", "TC002", "TC003"]
 flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
 flake8-annotations.allow-star-arg-any = true
@@ -354,7 +354,7 @@ def test_prompt_from_file_with_partial_variables() -> None:
     template = "This is a {foo} test {bar}."
     partial_variables = {"bar": "baz"}
     # when
-    with mock.patch("builtins.open", mock.mock_open(read_data=template)):
+    with mock.patch("pathlib.Path.open", mock.mock_open(read_data=template)):
         prompt = PromptTemplate.from_file(
             "mock_file_name", partial_variables=partial_variables
         )
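Because `PromptTemplate.from_file` now reads through `pathlib`, the test patches `pathlib.Path.open` rather than `builtins.open`. A self-contained sketch of the same technique (the file name is a placeholder and nothing is touched on disk; `Path.read_text` calls `Path.open` internally, so the patch covers it as well):

from pathlib import Path
from unittest import mock

template = "This is a {foo} test {bar}."

with mock.patch("pathlib.Path.open", mock.mock_open(read_data=template)):
    content = Path("mock_file_name").read_text()

assert content == template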
@@ -1,20 +1,17 @@
 import concurrent.futures
-import glob
 import importlib
 import subprocess
 from pathlib import Path
 
 
 def test_importable_all() -> None:
-    for path in glob.glob("../core/langchain_core/*"):
-        relative_path = Path(path).parts[-1]
-        if relative_path.endswith(".typed"):
-            continue
-        module_name = relative_path.split(".")[0]
-        module = importlib.import_module("langchain_core." + module_name)
-        all_ = getattr(module, "__all__", [])
-        for cls_ in all_:
-            getattr(module, cls_)
+    for path in Path("../core/langchain_core/").glob("*"):
+        module_name = path.stem
+        if not module_name.startswith(".") and path.suffix != ".typed":
+            module = importlib.import_module("langchain_core." + module_name)
+            all_ = getattr(module, "__all__", [])
+            for cls_ in all_:
+                getattr(module, cls_)
 
 
 def try_to_import(module_name: str) -> tuple[int, str]:

@@ -37,12 +34,10 @@ def test_importable_all_via_subprocess() -> None:
     for one sequence of imports but not another.
     """
     module_names = []
-    for path in glob.glob("../core/langchain_core/*"):
-        relative_path = Path(path).parts[-1]
-        if relative_path.endswith(".typed"):
-            continue
-        module_name = relative_path.split(".")[0]
-        module_names.append(module_name)
+    for path in Path("../core/langchain_core/").glob("*"):
+        module_name = path.stem
+        if not module_name.startswith(".") and path.suffix != ".typed":
+            module_names.append(module_name)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
         futures = [
@@ -26,6 +26,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,

@@ -83,6 +84,15 @@ _message_type_lookups = {
 }
 
 
+class AnthropicTool(TypedDict):
+    """Anthropic tool definition."""
+
+    name: str
+    description: str
+    input_schema: Dict[str, Any]
+    cache_control: NotRequired[Dict[str, str]]
+
+
 def _format_image(image_url: str) -> Dict:
     """
     Formats an image of format data:image/jpeg;base64,{b64_string}

@@ -900,6 +910,8 @@ class ChatAnthropic(BaseChatModel):
         llm_output = {
             k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
         }
+        if "model" in llm_output and "model_name" not in llm_output:
+            llm_output["model_name"] = llm_output["model"]
         if (
             len(content) == 1
             and content[0]["type"] == "text"
@@ -952,6 +964,31 @@ class ChatAnthropic(BaseChatModel):
         data = await self._async_client.messages.create(**payload)
         return self._format_output(data, **kwargs)
 
+    def _get_llm_for_structured_output_when_thinking_is_enabled(
+        self,
+        schema: Union[Dict, type],
+        formatted_tool: AnthropicTool,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        thinking_admonition = (
+            "Anthropic structured output relies on forced tool calling, "
+            "which is not supported when `thinking` is enabled. This method will raise "
+            "langchain_core.exceptions.OutputParserException if tool calls are not "
+            "generated. Consider disabling `thinking` or adjust your prompt to ensure "
+            "the tool is called."
+        )
+        warnings.warn(thinking_admonition)
+        llm = self.bind_tools(
+            [schema],
+            structured_output_format={"kwargs": {}, "schema": formatted_tool},
+        )
+
+        def _raise_if_no_tool_calls(message: AIMessage) -> AIMessage:
+            if not message.tool_calls:
+                raise OutputParserException(thinking_admonition)
+            return message
+
+        return llm | _raise_if_no_tool_calls
+
     def bind_tools(
         self,
         tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
@@ -1249,11 +1286,17 @@ class ChatAnthropic(BaseChatModel):
         """  # noqa: E501
         formatted_tool = convert_to_anthropic_tool(schema)
         tool_name = formatted_tool["name"]
-        llm = self.bind_tools(
-            [schema],
-            tool_choice=tool_name,
-            structured_output_format={"kwargs": {}, "schema": formatted_tool},
-        )
+        if self.thinking is not None and self.thinking.get("type") == "enabled":
+            llm = self._get_llm_for_structured_output_when_thinking_is_enabled(
+                schema, formatted_tool
+            )
+        else:
+            llm = self.bind_tools(
+                [schema],
+                tool_choice=tool_name,
+                structured_output_format={"kwargs": {}, "schema": formatted_tool},
+            )
 
         if isinstance(schema, type) and is_basemodel_subclass(schema):
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema], first_tool_only=True
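A hedged sketch of how the `thinking` branch above is exercised, mirroring the integration test added later in this diff (the model name and the `GenerateUsername` fields are illustrative; the call emits the warning and raises `OutputParserException` if the model answers without a tool call):

from pydantic import BaseModel

from langchain_anthropic import ChatAnthropic


class GenerateUsername(BaseModel):
    """Generate a username from a person's name and hair color."""

    name: str  # assumed field
    hair_color: str  # assumed field


llm = ChatAnthropic(
    model="claude-3-7-sonnet-latest",
    max_tokens=5_000,
    thinking={"type": "enabled", "budget_tokens": 2_000},
)
# Warns that forced tool calling is unavailable with thinking enabled, binds the
# tool without tool_choice, and validates that a tool call actually came back.
structured_llm = llm.with_structured_output(GenerateUsername)
result = structured_llm.invoke("Generate a username for Sally with green hair")
print(result)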
@@ -1356,15 +1399,6 @@ class ChatAnthropic(BaseChatModel):
         return response.input_tokens
 
 
-class AnthropicTool(TypedDict):
-    """Anthropic tool definition."""
-
-    name: str
-    description: str
-    input_schema: Dict[str, Any]
-    cache_control: NotRequired[Dict[str, str]]
-
-
 def convert_to_anthropic_tool(
     tool: Union[Dict[str, Any], Type, Callable, BaseTool],
 ) -> AnthropicTool:
@@ -1445,9 +1479,14 @@ def _make_message_chunk_from_anthropic_event(
     # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
     if event.type == "message_start" and stream_usage:
         usage_metadata = _create_usage_metadata(event.message.usage)
+        if hasattr(event.message, "model"):
+            response_metadata = {"model_name": event.message.model}
+        else:
+            response_metadata = {}
         message_chunk = AIMessageChunk(
             content="" if coerce_content_to_string else [],
             usage_metadata=usage_metadata,
+            response_metadata=response_metadata,
         )
     elif (
         event.type == "content_block_start"
@@ -6,7 +6,9 @@ from typing import List, Optional
 
 import pytest
 import requests
+from anthropic import BadRequestError
 from langchain_core.callbacks import CallbackManager
+from langchain_core.exceptions import OutputParserException
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,

@@ -35,6 +37,7 @@ def test_stream() -> None:
     full: Optional[BaseMessageChunk] = None
     chunks_with_input_token_counts = 0
     chunks_with_output_token_counts = 0
+    chunks_with_model_name = 0
     for token in llm.stream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
         full = token if full is None else full + token

@@ -44,12 +47,14 @@ def test_stream() -> None:
             chunks_with_input_token_counts += 1
         elif token.usage_metadata.get("output_tokens"):
             chunks_with_output_token_counts += 1
+        chunks_with_model_name += int("model_name" in token.response_metadata)
     if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
         raise AssertionError(
             "Expected exactly one chunk with input or output token counts. "
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
+    assert chunks_with_model_name == 1
     # check token usage is populated
     assert isinstance(full, AIMessageChunk)
     assert full.usage_metadata is not None

@@ -62,6 +67,7 @@ def test_stream() -> None:
     )
     assert "stop_reason" in full.response_metadata
     assert "stop_sequence" in full.response_metadata
+    assert "model_name" in full.response_metadata
 
 
 async def test_astream() -> None:

@@ -219,6 +225,7 @@ async def test_ainvoke() -> None:
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
+    assert "model_name" in result.response_metadata
 
 
 def test_invoke() -> None:
@@ -725,3 +732,39 @@ def test_redacted_thinking() -> None:
             assert set(block.keys()) == {"type", "data", "index"}
             assert block["data"] and isinstance(block["data"], str)
     assert stream_has_reasoning
+
+
+def test_structured_output_thinking_enabled() -> None:
+    llm = ChatAnthropic(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=5_000,
+        thinking={"type": "enabled", "budget_tokens": 2_000},
+    )
+    with pytest.warns(match="structured output"):
+        structured_llm = llm.with_structured_output(GenerateUsername)
+    query = "Generate a username for Sally with green hair"
+    response = structured_llm.invoke(query)
+    assert isinstance(response, GenerateUsername)
+
+    with pytest.raises(OutputParserException):
+        structured_llm.invoke("Hello")
+
+    # Test streaming
+    for chunk in structured_llm.stream(query):
+        assert isinstance(chunk, GenerateUsername)
+
+
+def test_structured_output_thinking_force_tool_use() -> None:
+    # Structured output currently relies on forced tool use, which is not supported
+    # when `thinking` is enabled. When this test fails, it means that the feature
+    # is supported and the workarounds in `with_structured_output` should be removed.
+    llm = ChatAnthropic(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=5_000,
+        thinking={"type": "enabled", "budget_tokens": 2_000},
+    ).bind_tools(
+        [GenerateUsername],
+        tool_choice="GenerateUsername",
+    )
+    with pytest.raises(BadRequestError):
+        llm.invoke("Generate a username for Sally with green hair")
@@ -579,7 +579,11 @@ class ChatMistralAI(BaseChatModel):
             )
             generations.append(gen)
 
-        llm_output = {"token_usage": token_usage, "model": self.model}
+        llm_output = {
+            "token_usage": token_usage,
+            "model_name": self.model,
+            "model": self.model,  # Backwards compatibility
+        }
         return ChatResult(generations=generations, llm_output=llm_output)
 
     def _create_message_dicts(
@@ -87,6 +87,7 @@ async def test_ainvoke() -> None:
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
+    assert "model_name" in result.response_metadata
 
 
 def test_invoke() -> None:
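With `model_name` now included in `llm_output`, callers can read it from `response_metadata`, matching the Anthropic change earlier in this merge. A brief usage sketch (requires a valid `MISTRAL_API_KEY`; the model name is illustrative):

from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(model="mistral-large-latest")  # illustrative model
result = llm.invoke("I'm Pickle Rick")

# "model" is kept alongside "model_name" for backwards compatibility.
print(result.response_metadata["model_name"])
print(result.response_metadata["model"])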