refactor(anthropic): AnthropicLLM to use Messages API (#32290)

re: #32189
Mason Daugherty 2025-07-28 16:22:58 -04:00 committed by GitHub
parent e5fd67024c
commit 3a487bf720
4 changed files with 134 additions and 112 deletions
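
For orientation, a minimal usage sketch of the refactored AnthropicLLM after this change. The defaults mirror the new field values in this commit, the prompts are illustrative, and an ANTHROPIC_API_KEY environment variable is assumed:

    from langchain_anthropic import AnthropicLLM

    # New default model after this commit; max_tokens replaces max_tokens_to_sample.
    llm = AnthropicLLM(model="claude-3-5-sonnet-latest")

    # invoke() now goes through client.messages.create(...) and returns the text
    # of the first content block of the response.
    print(llm.invoke("Say foo:"))

    # stream() wraps client.messages.stream(...) and yields text deltas.
    for chunk in llm.stream("Say foo:"):
        print(chunk, end="")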

View File

@ -3,11 +3,7 @@ from __future__ import annotations
import re
import warnings
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import (
Any,
Callable,
Optional,
)
from typing import Any, Callable, Optional
import anthropic
from langchain_core._api.deprecation import deprecated
@ -19,14 +15,8 @@ from langchain_core.language_models import BaseLanguageModel, LangSmithParams
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.prompt_values import PromptValue
from langchain_core.utils import (
get_pydantic_field_names,
)
from langchain_core.utils.utils import (
_build_model_kwargs,
from_env,
secret_from_env,
)
from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
@ -34,10 +24,10 @@ from typing_extensions import Self
class _AnthropicCommon(BaseLanguageModel):
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model: str = Field(default="claude-2", alias="model_name")
model: str = Field(default="claude-3-5-sonnet-latest", alias="model_name")
"""Model name to use."""
max_tokens_to_sample: int = Field(default=1024, alias="max_tokens")
max_tokens: int = Field(default=1024, alias="max_tokens_to_sample")
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None
@ -104,15 +94,16 @@ class _AnthropicCommon(BaseLanguageModel):
timeout=self.default_request_timeout,
max_retries=self.max_retries,
)
self.HUMAN_PROMPT = anthropic.HUMAN_PROMPT
self.AI_PROMPT = anthropic.AI_PROMPT
# Kept for backward compatibility; not used by the Messages API
self.HUMAN_PROMPT = getattr(anthropic, "HUMAN_PROMPT", None)
self.AI_PROMPT = getattr(anthropic, "AI_PROMPT", None)
return self
@property
def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling Anthropic API."""
d = {
"max_tokens_to_sample": self.max_tokens_to_sample,
"max_tokens": self.max_tokens,
"model": self.model,
}
if self.temperature is not None:
@ -129,16 +120,8 @@ class _AnthropicCommon(BaseLanguageModel):
return {**self._default_params}
def _get_anthropic_stop(self, stop: Optional[list[str]] = None) -> list[str]:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
msg = "Please ensure the anthropic package is loaded"
raise NameError(msg)
if stop is None:
stop = []
# Never want model to invent new turns of Human / Assistant dialog.
stop.extend([self.HUMAN_PROMPT])
return stop
@ -192,7 +175,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
"""Get the identifying parameters."""
return {
"model": self.model,
"max_tokens": self.max_tokens_to_sample,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
@ -211,27 +194,51 @@ class AnthropicLLM(LLM, _AnthropicCommon):
params = super()._get_ls_params(stop=stop, **kwargs)
identifying_params = self._identifying_params
if max_tokens := kwargs.get(
"max_tokens_to_sample",
"max_tokens",
identifying_params.get("max_tokens"),
):
params["ls_max_tokens"] = max_tokens
return params
def _wrap_prompt(self, prompt: str) -> str:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
msg = "Please ensure the anthropic package is loaded"
raise NameError(msg)
def _format_messages(self, prompt: str) -> list[dict[str, str]]:
"""Convert prompt to Messages API format."""
messages = []
if prompt.startswith(self.HUMAN_PROMPT):
return prompt # Already wrapped.
# Handle legacy prompts that might have HUMAN_PROMPT/AI_PROMPT markers
if self.HUMAN_PROMPT and self.HUMAN_PROMPT in prompt:
# Split on human/assistant turns
parts = prompt.split(self.HUMAN_PROMPT)
# Guard against common errors in specifying wrong number of newlines.
corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt)
if n_subs == 1:
return corrected_prompt
for part in parts:
if not part.strip():
continue
# As a last resort, wrap the prompt ourselves to emulate instruct-style.
return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
if self.AI_PROMPT and self.AI_PROMPT in part:
# Split human and assistant parts
human_part, assistant_part = part.split(self.AI_PROMPT, 1)
if human_part.strip():
messages.append({"role": "user", "content": human_part.strip()})
if assistant_part.strip():
messages.append(
{"role": "assistant", "content": assistant_part.strip()}
)
else:
# Just human content
if part.strip():
messages.append({"role": "user", "content": part.strip()})
else:
# Handle modern format or plain text
# Clean prompt for Messages API
content = re.sub(r"^\n*Human:\s*", "", prompt)
content = re.sub(r"\n*Assistant:\s*.*$", "", content)
if content.strip():
messages.append({"role": "user", "content": content.strip()})
# Ensure we have at least one message
if not messages:
messages = [{"role": "user", "content": prompt.strip() or "Hello"}]
return messages
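
To make the conversion concrete, a small sketch of what _format_messages yields, assuming the SDK constants HUMAN_PROMPT == "\n\nHuman:" and AI_PROMPT == "\n\nAssistant:" (the prompts are illustrative):

    # Legacy completions-style prompt with turn markers:
    _format_messages("\n\nHuman: What is 2+2?\n\nAssistant:")
    # -> [{"role": "user", "content": "What is 2+2?"}]

    # Plain text without markers becomes a single user message:
    _format_messages("Say foo:")
    # -> [{"role": "user", "content": "Say foo:"}]
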
def _call(
self,
@ -272,15 +279,19 @@ class AnthropicLLM(LLM, _AnthropicCommon):
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
response = self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
# Remove parameters not supported by Messages API
params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
response = self.client.messages.create(
messages=self._format_messages(prompt),
stop_sequences=stop if stop else None,
**params,
)
return response.completion
return response.content[0].text
def convert_prompt(self, prompt: PromptValue) -> str:
return self._wrap_prompt(prompt.to_string())
return prompt.to_string()
async def _acall(
self,
@ -304,12 +315,15 @@ class AnthropicLLM(LLM, _AnthropicCommon):
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
response = await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
# Remove parameters not supported by Messages API
params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
response = await self.async_client.messages.create(
messages=self._format_messages(prompt),
stop_sequences=stop if stop else None,
**params,
)
return response.completion
return response.content[0].text
def _stream(
self,
@ -343,14 +357,17 @@ class AnthropicLLM(LLM, _AnthropicCommon):
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
for token in self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
):
chunk = GenerationChunk(text=token.completion)
# Remove parameters not supported by Messages API
params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
with self.client.messages.stream(
messages=self._format_messages(prompt),
stop_sequences=stop if stop else None,
**params,
) as stream:
for event in stream:
if event.type == "content_block_delta" and hasattr(event.delta, "text"):
chunk = GenerationChunk(text=event.delta.text)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
@ -386,14 +403,17 @@ class AnthropicLLM(LLM, _AnthropicCommon):
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
async for token in await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
):
chunk = GenerationChunk(text=token.completion)
# Remove parameters not supported by Messages API
params = {k: v for k, v in params.items() if k != "max_tokens_to_sample"}
async with self.async_client.messages.stream(
messages=self._format_messages(prompt),
stop_sequences=stop if stop else None,
**params,
) as stream:
async for event in stream:
if event.type == "content_block_delta" and hasattr(event.delta, "text"):
chunk = GenerationChunk(text=event.delta.text)
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
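
For reference, the _stream/_astream bodies above follow the Anthropic SDK's streaming helper, which yields typed events rather than completion strings; only content_block_delta events carry text, which is why the code filters on that type. A rough sketch of the same pattern against the raw SDK (model, prompt, and token limit are illustrative):

    import anthropic

    client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
    with client.messages.stream(
        model="claude-3-5-sonnet-latest",
        max_tokens=1024,
        messages=[{"role": "user", "content": "Say foo:"}],
    ) as stream:
        for event in stream:
            # Other event types (message_start, content_block_start, message_stop, ...)
            # carry no text delta.
            if event.type == "content_block_delta" and hasattr(event.delta, "text"):
                print(event.delta.text, end="")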

View File

@ -31,8 +31,8 @@ from pydantic import BaseModel, Field
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from tests.unit_tests._utils import FakeCallbackHandler
MODEL_NAME = "claude-3-5-haiku-latest"
IMAGE_MODEL_NAME = "claude-3-5-sonnet-latest"
MODEL_NAME = "claude-3-5-haiku-20241022"
IMAGE_MODEL_NAME = "claude-3-5-sonnet-20241022"
def test_stream() -> None:
@ -178,7 +178,7 @@ async def test_abatch_tags() -> None:
async def test_async_tool_use() -> None:
llm = ChatAnthropic(
model=MODEL_NAME,
model=MODEL_NAME, # type: ignore[call-arg]
)
llm_with_tools = llm.bind_tools(
@ -274,7 +274,7 @@ def test_system_invoke() -> None:
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
chat = ChatAnthropic(model=MODEL_NAME)
chat = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
@ -283,7 +283,7 @@ def test_anthropic_call() -> None:
def test_anthropic_generate() -> None:
"""Test generate method of anthropic."""
chat = ChatAnthropic(model=MODEL_NAME)
chat = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
chat_messages: list[list[BaseMessage]] = [
[HumanMessage(content="How many toes do dogs have?")],
]
@ -299,7 +299,7 @@ def test_anthropic_generate() -> None:
def test_anthropic_streaming() -> None:
"""Test streaming tokens from anthropic."""
chat = ChatAnthropic(model=MODEL_NAME)
chat = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.stream([message])
for token in response:
@ -312,7 +312,7 @@ def test_anthropic_streaming_callback() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
model=MODEL_NAME,
model=MODEL_NAME, # type: ignore[call-arg]
callback_manager=callback_manager,
verbose=True,
)
@ -328,7 +328,7 @@ async def test_anthropic_async_streaming_callback() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
model=MODEL_NAME,
model=MODEL_NAME, # type: ignore[call-arg]
callback_manager=callback_manager,
verbose=True,
)
@ -343,7 +343,7 @@ async def test_anthropic_async_streaming_callback() -> None:
def test_anthropic_multimodal() -> None:
"""Test that multimodal inputs are handled correctly."""
chat = ChatAnthropic(model=IMAGE_MODEL_NAME)
chat = ChatAnthropic(model=IMAGE_MODEL_NAME) # type: ignore[call-arg]
messages: list[BaseMessage] = [
HumanMessage(
content=[
@ -399,7 +399,7 @@ async def test_astreaming() -> None:
def test_tool_use() -> None:
llm = ChatAnthropic(
model="claude-3-7-sonnet-20250219",
model="claude-3-7-sonnet-20250219", # type: ignore[call-arg]
temperature=0,
)
tool_definition = {
@ -424,7 +424,7 @@ def test_tool_use() -> None:
# Test streaming
llm = ChatAnthropic(
model="claude-3-7-sonnet-20250219",
model="claude-3-7-sonnet-20250219", # type: ignore[call-arg]
temperature=0,
# Add extra headers to also test token-efficient tools
model_kwargs={
@ -492,7 +492,7 @@ def test_tool_use() -> None:
def test_builtin_tools() -> None:
llm = ChatAnthropic(model="claude-3-7-sonnet-20250219")
llm = ChatAnthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
tool = {"type": "text_editor_20250124", "name": "str_replace_editor"}
llm_with_tools = llm.bind_tools([tool])
response = llm_with_tools.invoke(
@ -510,7 +510,7 @@ class GenerateUsername(BaseModel):
def test_disable_parallel_tool_calling() -> None:
llm = ChatAnthropic(model="claude-3-5-sonnet-20241022")
llm = ChatAnthropic(model="claude-3-5-sonnet-20241022") # type: ignore[call-arg]
llm_with_tools = llm.bind_tools([GenerateUsername], parallel_tool_calls=False)
result = llm_with_tools.invoke(
"Use the GenerateUsername tool to generate user names for:\n\n"
@ -529,7 +529,7 @@ def test_anthropic_with_empty_text_block() -> None:
"""Type the given letter."""
return "OK"
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0).bind_tools(
model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0).bind_tools( # type: ignore[call-arg]
[type_letter],
)
@ -568,7 +568,7 @@ def test_anthropic_with_empty_text_block() -> None:
def test_with_structured_output() -> None:
llm = ChatAnthropic(
model="claude-3-opus-20240229",
model="claude-3-opus-20240229", # type: ignore[call-arg]
)
structured_llm = llm.with_structured_output(
@ -587,7 +587,7 @@ def test_with_structured_output() -> None:
def test_get_num_tokens_from_messages() -> None:
llm = ChatAnthropic(model="claude-3-5-sonnet-20241022")
llm = ChatAnthropic(model="claude-3-5-sonnet-20241022") # type: ignore[call-arg]
# Test simple case
messages = [
@ -650,7 +650,7 @@ class GetWeather(BaseModel):
@pytest.mark.parametrize("tool_choice", ["GetWeather", "auto", "any"])
def test_anthropic_bind_tools_tool_choice(tool_choice: str) -> None:
chat_model = ChatAnthropic(
model=MODEL_NAME,
model=MODEL_NAME, # type: ignore[call-arg]
)
chat_model_with_tools = chat_model.bind_tools([GetWeather], tool_choice=tool_choice)
response = chat_model_with_tools.invoke("what's the weather in ny and la")
@ -661,7 +661,7 @@ def test_pdf_document_input() -> None:
url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
data = b64encode(requests.get(url, timeout=10).content).decode()
result = ChatAnthropic(model=IMAGE_MODEL_NAME).invoke(
result = ChatAnthropic(model=IMAGE_MODEL_NAME).invoke( # type: ignore[call-arg]
[
HumanMessage(
[
@ -684,7 +684,7 @@ def test_pdf_document_input() -> None:
def test_citations() -> None:
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
llm = ChatAnthropic(model="claude-3-5-haiku-latest") # type: ignore[call-arg]
messages = [
{
"role": "user",
@ -729,8 +729,8 @@ def test_citations() -> None:
@pytest.mark.vcr
def test_thinking() -> None:
llm = ChatAnthropic(
model="claude-3-7-sonnet-latest",
max_tokens=5_000,
model="claude-3-7-sonnet-latest", # type: ignore[call-arg]
max_tokens=5_000, # type: ignore[call-arg]
thinking={"type": "enabled", "budget_tokens": 2_000},
)
@ -766,8 +766,8 @@ def test_thinking() -> None:
@pytest.mark.vcr
def test_redacted_thinking() -> None:
llm = ChatAnthropic(
model="claude-3-7-sonnet-latest",
max_tokens=5_000,
model="claude-3-7-sonnet-latest", # type: ignore[call-arg]
max_tokens=5_000, # type: ignore[call-arg]
thinking={"type": "enabled", "budget_tokens": 2_000},
)
query = "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB" # noqa: E501
@ -805,8 +805,8 @@ def test_redacted_thinking() -> None:
def test_structured_output_thinking_enabled() -> None:
llm = ChatAnthropic(
model="claude-3-7-sonnet-latest",
max_tokens=5_000,
model="claude-3-7-sonnet-latest", # type: ignore[call-arg]
max_tokens=5_000, # type: ignore[call-arg]
thinking={"type": "enabled", "budget_tokens": 2_000},
)
with pytest.warns(match="structured output"):
@ -828,8 +828,8 @@ def test_structured_output_thinking_force_tool_use() -> None:
# when `thinking` is enabled. When this test fails, it means that the feature
# is supported and the workarounds in `with_structured_output` should be removed.
llm = ChatAnthropic(
model="claude-3-7-sonnet-latest",
max_tokens=5_000,
model="claude-3-7-sonnet-latest", # type: ignore[call-arg]
max_tokens=5_000, # type: ignore[call-arg]
thinking={"type": "enabled", "budget_tokens": 2_000},
).bind_tools(
[GenerateUsername],
@ -896,13 +896,13 @@ def test_image_tool_calling() -> None:
],
),
]
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
llm = ChatAnthropic(model="claude-3-5-sonnet-latest") # type: ignore[call-arg]
llm.bind_tools([color_picker]).invoke(messages)
@pytest.mark.vcr
def test_web_search() -> None:
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
llm = ChatAnthropic(model="claude-3-5-sonnet-latest") # type: ignore[call-arg]
tool = {"type": "web_search_20250305", "name": "web_search", "max_uses": 1}
llm_with_tools = llm.bind_tools([tool])
@ -944,9 +944,9 @@ def test_web_search() -> None:
@pytest.mark.vcr
def test_code_execution() -> None:
llm = ChatAnthropic(
model="claude-sonnet-4-20250514",
model="claude-sonnet-4-20250514", # type: ignore[call-arg]
betas=["code-execution-2025-05-22"],
max_tokens=10_000,
max_tokens=10_000, # type: ignore[call-arg]
)
tool = {"type": "code_execution_20250522", "name": "code_execution"}
@ -1002,10 +1002,10 @@ def test_remote_mcp() -> None:
]
llm = ChatAnthropic(
model="claude-sonnet-4-20250514",
model="claude-sonnet-4-20250514", # type: ignore[call-arg]
betas=["mcp-client-2025-04-04"],
mcp_servers=mcp_servers,
max_tokens=10_000,
max_tokens=10_000, # type: ignore[call-arg]
)
input_message = {
@ -1052,7 +1052,7 @@ def test_files_api_image(block_format: str) -> None:
if not image_file_id:
pytest.skip()
llm = ChatAnthropic(
model="claude-sonnet-4-20250514",
model="claude-sonnet-4-20250514", # type: ignore[call-arg]
betas=["files-api-2025-04-14"],
)
if block_format == "anthropic":
@ -1086,7 +1086,7 @@ def test_files_api_pdf(block_format: str) -> None:
if not pdf_file_id:
pytest.skip()
llm = ChatAnthropic(
model="claude-sonnet-4-20250514",
model="claude-sonnet-4-20250514", # type: ignore[call-arg]
betas=["files-api-2025-04-14"],
)
if block_format == "anthropic":
@ -1111,7 +1111,7 @@ def test_files_api_pdf(block_format: str) -> None:
def test_search_result_tool_message() -> None:
"""Test that we can pass a search result tool message to the model."""
llm = ChatAnthropic(
model="claude-3-5-haiku-latest",
model="claude-3-5-haiku-latest", # type: ignore[call-arg]
betas=["search-results-2025-06-09"],
)
@ -1164,7 +1164,7 @@ def test_search_result_tool_message() -> None:
def test_search_result_top_level() -> None:
llm = ChatAnthropic(
model="claude-3-5-haiku-latest",
model="claude-3-5-haiku-latest", # type: ignore[call-arg]
betas=["search-results-2025-06-09"],
)
input_message = HumanMessage(
@ -1209,6 +1209,6 @@ def test_search_result_top_level() -> None:
def test_async_shared_client() -> None:
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
llm = ChatAnthropic(model="claude-3-5-haiku-latest") # type: ignore[call-arg]
_ = asyncio.run(llm.ainvoke("Hello"))
_ = asyncio.run(llm.ainvoke("Hello"))

View File

@ -24,14 +24,14 @@ def test_anthropic_model_param() -> None:
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
llm = Anthropic(model="claude-2.1") # type: ignore[call-arg]
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
output = llm.invoke("Say foo:")
assert isinstance(output, str)
def test_anthropic_streaming() -> None:
"""Test streaming tokens from anthropic."""
llm = Anthropic(model="claude-2.1") # type: ignore[call-arg]
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
generator = llm.stream("I'm Pickle Rick")
assert isinstance(generator, Generator)
@ -45,6 +45,7 @@ def test_anthropic_streaming_callback() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = Anthropic(
model="claude-3-7-sonnet-20250219", # type: ignore[call-arg]
streaming=True,
callback_manager=callback_manager,
verbose=True,
@ -55,7 +56,7 @@ def test_anthropic_streaming_callback() -> None:
async def test_anthropic_async_generate() -> None:
"""Test async generate."""
llm = Anthropic()
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
output = await llm.agenerate(["How many toes do dogs have?"])
assert isinstance(output, LLMResult)
@ -65,6 +66,7 @@ async def test_anthropic_async_streaming_callback() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = Anthropic(
model="claude-3-7-sonnet-20250219", # type: ignore[call-arg]
streaming=True,
callback_manager=callback_manager,
verbose=True,

View File

@ -505,7 +505,7 @@ typing = [
[[package]]
name = "langchain-core"
version = "0.3.68"
version = "0.3.72"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },
@ -521,7 +521,7 @@ dependencies = [
requires-dist = [
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
{ name = "langsmith", specifier = ">=0.3.45" },
{ name = "packaging", specifier = ">=23.2,<25" },
{ name = "packaging", specifier = ">=23.2" },
{ name = "pydantic", specifier = ">=2.7.4" },
{ name = "pyyaml", specifier = ">=5.3" },
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },