diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index ce553dc468f..bd18af6c670 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -43,6 +43,11 @@ from langchain_core.messages import ( ToolMessage, ) from langchain_core.messages.ai import UsageMetadata +from langchain_core.output_parsers import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, +) +from langchain_core.output_parsers.base import OutputParserLike from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.runnables import ( @@ -58,7 +63,7 @@ from langchain_core.utils import ( ) from langchain_core.utils.function_calling import convert_to_openai_tool -from langchain_anthropic.output_parsers import ToolsOutputParser, extract_tool_calls +from langchain_anthropic.output_parsers import extract_tool_calls _message_type_lookups = { "human": "user", @@ -990,11 +995,13 @@ class ChatAnthropic(BaseChatModel): tool_name = convert_to_anthropic_tool(schema)["name"] llm = self.bind_tools([schema], tool_choice=tool_name) if isinstance(schema, type) and issubclass(schema, BaseModel): - output_parser = ToolsOutputParser( - first_tool_only=True, pydantic_schemas=[schema] + output_parser: OutputParserLike = PydanticToolsParser( + tools=[schema], first_tool_only=True ) else: - output_parser = ToolsOutputParser(first_tool_only=True, args_only=True) + output_parser = JsonOutputKeyToolsParser( + key_name=tool_name, first_tool_only=True + ) if include_raw: parser_assign = RunnablePassthrough.assign( diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 14957fa59c9..53ba4ab161d 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -151,10 +151,25 @@ class ChatModelIntegrationTests(ChatModelTests): setup: str = Field(description="question to set up a joke") punchline: str = Field(description="answer to resolve the joke") + # Pydantic class chat = model.with_structured_output(Joke) result = chat.invoke("Tell me a joke about cats.") assert isinstance(result, Joke) + for chunk in chat.stream("Tell me a joke about cats."): + assert isinstance(chunk, Joke) + + # Schema + chat = model.with_structured_output(Joke.schema()) + result = chat.invoke("Tell me a joke about cats.") + assert isinstance(result, dict) + assert set(result.keys()) == {"setup", "punchline"} + + for chunk in chat.stream("Tell me a joke about cats."): + assert isinstance(chunk, dict) + assert isinstance(chunk, dict) # for mypy + assert set(chunk.keys()) == {"setup", "punchline"} + def test_tool_message_histories_string_content( self, model: BaseChatModel,