# Reconstructed from a line-collapsed `git format-patch` file:
#   commit b40c80007f344e4c6112ed7fc3a6aa1446333da8
#   From: Harrison Chase
#   Date: Fri, 22 Mar 2024 10:17:40 -0700
#   Subject: [PATCH] core[minor]: Add utility code to create tool examples (#18602)
#   Co-authored-by: Chester Curme
#
# The patch adds code to two files. What it adds to each is reproduced below
# as plain Python, in patch order:
#   1. libs/core/langchain_core/utils/function_calling.py
#   2. libs/core/tests/unit_tests/utils/test_function_calling.py

import uuid
from typing import List, Optional

import pytest

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel

# ---------------------------------------------------------------------------
# 1. libs/core/langchain_core/utils/function_calling.py
# ---------------------------------------------------------------------------


def tool_example_to_messages(
    input: str, tool_calls: List[BaseModel], tool_outputs: Optional[List[str]] = None
) -> List[BaseMessage]:
    """Convert an example into a list of messages that can be fed into an LLM.

    This code is an adapter that converts a single example to a list of messages
    that can be fed into a chat model.

    The list of messages per example corresponds to:

    1) HumanMessage: contains the content from which content should be extracted.
    2) AIMessage: contains the extracted information from the model
    3) ToolMessage: contains confirmation to the model that the model requested a
       tool correctly. One ToolMessage is emitted per tool call.

    The ToolMessage is required because some chat models are hyper-optimized for
    agents rather than for an extraction use case.

    Args:
        input: string, the user input
        tool_calls: List[BaseModel], a list of tool calls represented as Pydantic
            BaseModels
        tool_outputs: Optional[List[str]], a list of tool call outputs, one per
            entry of ``tool_calls`` and in the same order.
            Does not need to be provided. If not provided (``None`` or empty),
            a placeholder value will be inserted for every tool call.

    Returns:
        A list of messages

    Raises:
        ValueError: If ``tool_outputs`` is non-empty and its length differs from
            the length of ``tool_calls``.

    Examples:

        .. code-block:: python

            from typing import List, Optional
            from langchain_core.pydantic_v1 import BaseModel, Field
            from langchain_core.utils.function_calling import (
                tool_example_to_messages,
            )

            class Person(BaseModel):
                '''Information about a person.'''
                name: Optional[str] = Field(..., description="The name of the person")
                hair_color: Optional[str] = Field(
                    ..., description="The color of the person's hair if known"
                )
                height_in_meters: Optional[str] = Field(
                    ..., description="Height in METERs"
                )

            examples = [
                (
                    "The ocean is vast and blue. It's more than 20,000 feet deep.",
                    Person(name=None, height_in_meters=None, hair_color=None),
                ),
                (
                    "Fiona traveled far from France to Spain.",
                    Person(name="Fiona", height_in_meters=None, hair_color=None),
                ),
            ]


            messages = []

            for txt, tool_call in examples:
                messages.extend(
                    tool_example_to_messages(txt, [tool_call])
                )
    """
    # Fail fast on a length mismatch: pairing the two lists with zip() would
    # otherwise silently drop the surplus entries and leave tool calls without
    # a matching ToolMessage (or discard caller-supplied outputs).
    if tool_outputs and len(tool_outputs) != len(tool_calls):
        raise ValueError(
            f"Expected one tool output per tool call, got {len(tool_outputs)} "
            f"tool output(s) for {len(tool_calls)} tool call(s)."
        )

    messages: List[BaseMessage] = [HumanMessage(content=input)]
    openai_tool_calls = [
        {
            # Each call gets a fresh random id; the ToolMessage built for this
            # call further down refers back to it through ``tool_call_id``.
            "id": str(uuid.uuid4()),
            "type": "function",
            "function": {
                # The name of the function right now corresponds to the name
                # of the pydantic model. This is implicit in the API right now,
                # and will be improved over time.
                "name": tool_call.__class__.__name__,
                "arguments": tool_call.json(),
            },
        }
        for tool_call in tool_calls
    ]
    messages.append(
        AIMessage(content="", additional_kwargs={"tool_calls": openai_tool_calls})
    )
    if not tool_outputs:
        # None and [] both mean "not provided": acknowledge every call with the
        # same placeholder text.
        tool_outputs = ["You have correctly called this tool."] * len(
            openai_tool_calls
        )
    for output, tool_call_dict in zip(tool_outputs, openai_tool_calls):
        messages.append(
            ToolMessage(content=output, tool_call_id=tool_call_dict["id"])  # type: ignore
        )
    return messages


# ---------------------------------------------------------------------------
# 2. libs/core/tests/unit_tests/utils/test_function_calling.py
# ---------------------------------------------------------------------------


class FakeCall(BaseModel):
    """Minimal pydantic model standing in for a tool call in the tests."""

    data: str


def test_valid_example_conversion() -> None:
    """With no tool calls only the human and the (empty) AI message are built."""
    expected_messages = [
        HumanMessage(content="This is a valid example"),
        AIMessage(content="", additional_kwargs={"tool_calls": []}),
    ]
    assert (
        tool_example_to_messages(input="This is a valid example", tool_calls=[])
        == expected_messages
    )


def test_multiple_tool_calls() -> None:
    """Every tool call yields one entry in the AI message and one ToolMessage."""
    messages = tool_example_to_messages(
        input="This is an example",
        tool_calls=[
            FakeCall(data="ToolCall1"),
            FakeCall(data="ToolCall2"),
            FakeCall(data="ToolCall3"),
        ],
    )
    assert len(messages) == 5
    assert isinstance(messages[0], HumanMessage)
    assert isinstance(messages[1], AIMessage)
    assert isinstance(messages[2], ToolMessage)
    assert isinstance(messages[3], ToolMessage)
    assert isinstance(messages[4], ToolMessage)
    assert messages[1].additional_kwargs["tool_calls"] == [
        {
            "id": messages[2].tool_call_id,
            "type": "function",
            "function": {"name": "FakeCall", "arguments": '{"data": "ToolCall1"}'},
        },
        {
            "id": messages[3].tool_call_id,
            "type": "function",
            "function": {"name": "FakeCall", "arguments": '{"data": "ToolCall2"}'},
        },
        {
            "id": messages[4].tool_call_id,
            "type": "function",
            "function": {"name": "FakeCall", "arguments": '{"data": "ToolCall3"}'},
        },
    ]


def test_tool_outputs() -> None:
    """A caller-supplied output replaces the placeholder ToolMessage content."""
    messages = tool_example_to_messages(
        input="This is an example",
        tool_calls=[
            FakeCall(data="ToolCall1"),
        ],
        tool_outputs=["Output1"],
    )
    assert len(messages) == 3
    assert isinstance(messages[0], HumanMessage)
    assert isinstance(messages[1], AIMessage)
    assert isinstance(messages[2], ToolMessage)
    assert messages[1].additional_kwargs["tool_calls"] == [
        {
            "id": messages[2].tool_call_id,
            "type": "function",
            "function": {"name": "FakeCall", "arguments": '{"data": "ToolCall1"}'},
        },
    ]
    assert messages[2].content == "Output1"


def test_mismatched_tool_outputs() -> None:
    """A non-empty tool_outputs of the wrong length is rejected, not truncated."""
    with pytest.raises(ValueError):
        tool_example_to_messages(
            input="This is an example",
            tool_calls=[
                FakeCall(data="ToolCall1"),
                FakeCall(data="ToolCall2"),
            ],
            tool_outputs=["Output1"],
        )