From a433039a56d1f70618ae34f4bd097b25d50e9b9f Mon Sep 17 00:00:00 2001 From: ccurme Date: Fri, 22 Nov 2024 10:38:49 -0500 Subject: [PATCH] core[patch]: support final AIMessage responses in `tool_example_to_messages` (#28267) We have a test [test_structured_few_shot_examples](https://github.com/langchain-ai/langchain/blob/ad4333ca032033097c663dfe818c5c892c368bd6/libs/standard-tests/langchain_tests/integration_tests/chat_models.py#L546) in standard integration tests that implements a version of tool-calling few shot examples that works with ~all tested providers. The formulation supported by ~all providers is: `human message, tool call, tool message, AI response`. Here we update `langchain_core.utils.function_calling.tool_example_to_messages` to support this formulation. The `tool_example_to_messages` util is undocumented outside of our API reference. IMO, if we are testing that this function works across all providers, it can be helpful to feature it in our guides. The structured few-shot examples we document at the moment require users to implement this function and can be simplified. 
--- .../langchain_core/utils/function_calling.py | 17 +++++++-- .../unit_tests/utils/test_function_calling.py | 18 +++++++++ .../integration_tests/chat_models.py | 38 +++++++------------ 3 files changed, 45 insertions(+), 28 deletions(-) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 3aff07faecd..4779d262442 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -22,7 +22,7 @@ from typing import ( from pydantic import BaseModel from typing_extensions import TypedDict, get_args, get_origin, is_typeddict -from langchain_core._api import deprecated +from langchain_core._api import beta, deprecated from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage from langchain_core.utils.json_schema import dereference_refs from langchain_core.utils.pydantic import is_basemodel_subclass @@ -494,21 +494,28 @@ def convert_to_openai_tool( return {"type": "function", "function": oai_function} +@beta() def tool_example_to_messages( - input: str, tool_calls: list[BaseModel], tool_outputs: Optional[list[str]] = None + input: str, + tool_calls: list[BaseModel], + tool_outputs: Optional[list[str]] = None, + *, + ai_response: Optional[str] = None, ) -> list[BaseMessage]: """Convert an example into a list of messages that can be fed into an LLM. This code is an adapter that converts a single example to a list of messages that can be fed into a chat model. - The list of messages per example corresponds to: + The list of messages per example by default corresponds to: 1) HumanMessage: contains the content from which content should be extracted. 2) AIMessage: contains the extracted information from the model 3) ToolMessage: contains confirmation to the model that the model requested a tool correctly. + If `ai_response` is specified, there will be a final AIMessage with that response. 
+ The ToolMessage is required because some chat models are hyper-optimized for agents rather than for an extraction use case. @@ -519,6 +526,7 @@ def tool_example_to_messages( tool_outputs: Optional[List[str]], a list of tool call outputs. Does not need to be provided. If not provided, a placeholder value will be inserted. Defaults to None. + ai_response: Optional[str], if provided, content for a final AIMessage. Returns: A list of messages @@ -584,6 +592,9 @@ def tool_example_to_messages( ) for output, tool_call_dict in zip(tool_outputs, openai_tool_calls): messages.append(ToolMessage(content=output, tool_call_id=tool_call_dict["id"])) # type: ignore + + if ai_response: + messages.append(AIMessage(content=ai_response)) return messages diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 4eaa3da2b19..ba4c50187f1 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -679,6 +679,24 @@ def test_tool_outputs() -> None: ] assert messages[2].content == "Output1" + # Test final AI response + messages = tool_example_to_messages( + input="This is an example", + tool_calls=[ + FakeCall(data="ToolCall1"), + ], + tool_outputs=["Output1"], + ai_response="The output is Output1", + ) + assert len(messages) == 4 + assert isinstance(messages[0], HumanMessage) + assert isinstance(messages[1], AIMessage) + assert isinstance(messages[2], ToolMessage) + assert isinstance(messages[3], AIMessage) + response = messages[3] + assert response.content == "The output is Output1" + assert not response.tool_calls + @pytest.mark.parametrize("use_extension_typed_dict", [True, False]) @pytest.mark.parametrize("use_extension_annotated", [True, False]) diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py index 61f45e63b9b..5ef6f99c821 
100644 --- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py @@ -17,6 +17,7 @@ from langchain_core.messages import ( from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import tool +from langchain_core.utils.function_calling import tool_example_to_messages from pydantic import BaseModel, Field from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import Field as FieldV1 @@ -857,33 +858,20 @@ class ChatModelIntegrationTests(ChatModelTests): if not self.has_tool_calling: pytest.skip("Test requires tool calling.") model_with_tools = model.bind_tools([my_adder_tool], tool_choice="any") - function_name = "my_adder_tool" - function_args = {"a": 1, "b": 2} function_result = json.dumps({"result": 3}) - messages_string_content = [ - HumanMessage("What is 1 + 2"), - AIMessage( - "", - tool_calls=[ - { - "name": function_name, - "args": function_args, - "id": "abc123", - "type": "tool_call", - }, - ], - ), - ToolMessage( - function_result, - name=function_name, - tool_call_id="abc123", - ), - AIMessage(function_result), - HumanMessage("What is 3 + 4"), - ] - result_string_content = model_with_tools.invoke(messages_string_content) - assert isinstance(result_string_content, AIMessage) + tool_schema = my_adder_tool.args_schema + assert tool_schema is not None + few_shot_messages = tool_example_to_messages( + "What is 1 + 2", + [tool_schema(a=1, b=2)], + tool_outputs=[function_result], + ai_response=function_result, + ) + + messages = few_shot_messages + [HumanMessage("What is 3 + 4")] + result = model_with_tools.invoke(messages) + assert isinstance(result, AIMessage) def test_image_inputs(self, model: BaseChatModel) -> None: if not self.supports_image_inputs: