From bfd719f9d8454c69ce4fc8df2e8b4c96cb446ce3 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Tue, 31 Oct 2023 23:15:37 +0900 Subject: [PATCH] bind_functions convenience method (#12518) I always take 20-30 seconds to re-discover where the `convert_to_openai_function` wrapper lives in our codebase. Chat langchain [has no clue](https://smith.langchain.com/public/3989d687-18c7-4108-958e-96e88803da86/r) what to do either. There's the older `create_openai_fn_chain`, but we haven't been recommending it in LCEL. The example we show in the [cookbook](https://python.langchain.com/docs/expression_language/how_to/binding#attaching-openai-functions) is really verbose. General function calling should be as simple as possible to do, so this seems a bit more ergonomic to me (feel free to disagree). Another option would be to coerce directly in the class's init (or when calling invoke), if provided. I'm not 100% set against that. That approach may be too easy but not simple. This PR feels like a decent compromise between simple and easy. 
``` from enum import Enum from typing import Optional from pydantic import BaseModel, Field class Category(str, Enum): """The category of the issue.""" bug = "bug" nit = "nit" improvement = "improvement" other = "other" class IssueClassification(BaseModel): """Classify an issue.""" category: Category other_description: Optional[str] = Field( description="If classified as 'other', the suggested other category" ) from langchain.chat_models import ChatOpenAI llm = ChatOpenAI().bind_functions([IssueClassification]) llm.invoke("This PR adds a convenience wrapper to the bind argument") # AIMessage(content='', additional_kwargs={'function_call': {'name': 'IssueClassification', 'arguments': '{\n "category": "improvement"\n}'}}) ``` --- .../langchain/langchain/chat_models/openai.py | 48 ++++++++++++++++++- .../chat_models/test_openai.py | 42 ++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/chat_models/openai.py b/libs/langchain/langchain/chat_models/openai.py index a09dfd3e9cb..d2993c13baf 100644 --- a/libs/langchain/langchain/chat_models/openai.py +++ b/libs/langchain/langchain/chat_models/openai.py @@ -13,6 +13,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Tuple, Type, Union, @@ -29,8 +30,9 @@ from langchain.chat_models.base import ( _generate_from_stream, ) from langchain.llms.base import create_base_retry_decorator -from langchain.pydantic_v1 import Field, root_validator +from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.schema import ChatGeneration, ChatResult +from langchain.schema.language_model import LanguageModelInput from langchain.schema.messages import ( AIMessageChunk, BaseMessage, @@ -41,11 +43,13 @@ from langchain.schema.messages import ( SystemMessageChunk, ) from langchain.schema.output import ChatGenerationChunk +from langchain.schema.runnable import Runnable from langchain.utils import get_from_dict_or_env, get_pydantic_field_names if TYPE_CHECKING: 
import tiktoken + logger = logging.getLogger(__name__) @@ -540,3 +544,45 @@ class ChatOpenAI(BaseChatModel): # every reply is primed with assistant num_tokens += 3 return num_tokens + + def bind_functions( + self, + functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], + function_call: Optional[str] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind functions (and other objects) to this chat model. + + Args: + functions: A list of function definitions to bind to this chat model. + Can be a dictionary, pydantic model, or callable. Pydantic + models and callables will be automatically converted to + their schema dictionary representation. + function_call: Which function to require the model to call. + Must be the name of the single provided function or + "auto" to automatically determine which function to call + (if any). + kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + from langchain.chains.openai_functions.base import convert_to_openai_function + + formatted_functions = [convert_to_openai_function(fn) for fn in functions] + function_call_ = None + if function_call is not None: + if len(formatted_functions) != 1: + raise ValueError( + "When specifying `function_call`, you must provide exactly one " + "function." + ) + if formatted_functions[0]["name"] != function_call: + raise ValueError( + f"Function call {function_call} was specified, but the only " + f"provided function was {formatted_functions[0]['name']}." 
+ ) + function_call_ = {"name": function_call} + kwargs = {**kwargs, "function_call": function_call_} + return super().bind( + functions=formatted_functions, + **kwargs, + ) diff --git a/libs/langchain/tests/integration_tests/chat_models/test_openai.py b/libs/langchain/tests/integration_tests/chat_models/test_openai.py index 5c8b0e43e6e..e1da41c384c 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_openai.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_openai.py @@ -9,7 +9,9 @@ from langchain.chains.openai_functions import ( create_openai_fn_chain, ) from langchain.chat_models.openai import ChatOpenAI +from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain.pydantic_v1 import BaseModel, Field from langchain.schema import ( ChatGeneration, ChatResult, @@ -297,6 +299,46 @@ async def test_async_chat_openai_streaming_with_function() -> None: assert all([chunk is not None for chunk in callback_handler._captured_chunks]) +@pytest.mark.scheduled +@pytest.mark.asyncio +async def test_async_chat_openai_bind_functions() -> None: + """Test ChatOpenAI wrapper with multiple completions.""" + + class Person(BaseModel): + """Identifying information about a person.""" + + name: str = Field(..., title="Name", description="The person's name") + age: int = Field(..., title="Age", description="The person's age") + fav_food: Optional[str] = Field( + default=None, title="Fav Food", description="The person's favorite food" + ) + + chat = ChatOpenAI( + max_tokens=30, + n=1, + streaming=True, + ).bind_functions(functions=[Person], function_call="Person") + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "Use the provided Person function"), + ("user", "{input}"), + ] + ) + + chain = prompt | chat | JsonOutputFunctionsParser(args_only=True) + + message = HumanMessage(content="Sally is 13 years old") + response = await 
chain.abatch([{"input": message}]) + + assert isinstance(response, list) + assert len(response) == 1 + for generation in response: + assert isinstance(generation, dict) + assert "name" in generation + assert "age" in generation + + def test_chat_openai_extra_kwargs() -> None: """Test extra kwargs to chat openai.""" # Check that foo is saved in extra_kwargs.