Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-19 13:23:35 +00:00)
bind_functions convenience method (#12518)
I always take 20-30 seconds to re-discover where the `convert_to_openai_function` wrapper lives in our codebase. Chat LangChain [has no clue](https://smith.langchain.com/public/3989d687-18c7-4108-958e-96e88803da86/r) what to do either. There's the older `create_openai_fn_chain`, but we haven't been recommending it in LCEL. The example we show in the [cookbook](https://python.langchain.com/docs/expression_language/how_to/binding#attaching-openai-functions) is really verbose.

General function calling should be as simple as possible to do, so this seems a bit more ergonomic to me (feel free to disagree). Another option would be to coerce directly in the class's init (or when calling invoke), if provided. I'm not 100% set against that, but that approach may be too easy without being simple. This PR feels like a decent compromise between simple and easy.

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field


class Category(str, Enum):
    """The category of the issue."""

    bug = "bug"
    nit = "nit"
    improvement = "improvement"
    other = "other"


class IssueClassification(BaseModel):
    """Classify an issue."""

    category: Category
    other_description: Optional[str] = Field(
        description="If classified as 'other', the suggested other category"
    )


from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI().bind_functions([IssueClassification])
llm.invoke("This PR adds a convenience wrapper to the bind argument")
# AIMessage(content='', additional_kwargs={'function_call': {'name': 'IssueClassification', 'arguments': '{\n "category": "improvement"\n}'}})
```
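For contrast, here is roughly what the same call looks like via the existing `bind` path shown in the cookbook, with the OpenAI function schema written out by hand. This is a sketch for illustration only (the schema dict and variable name are mine, not from the PR), but it conveys the verbosity the new method removes:

```python
from langchain.chat_models import ChatOpenAI

# Hand-written OpenAI function schema -- the part that bind_functions
# now derives from the pydantic model automatically.
issue_classification_schema = {
    "name": "IssueClassification",
    "description": "Classify an issue.",
    "parameters": {
        "type": "object",
        "properties": {
            "category": {
                "type": "string",
                "enum": ["bug", "nit", "improvement", "other"],
            },
            "other_description": {
                "type": "string",
                "description": "If classified as 'other', the suggested other category",
            },
        },
        "required": ["category"],
    },
}

llm = ChatOpenAI().bind(functions=[issue_classification_schema])
llm.invoke("This PR adds a convenience wrapper to the bind argument")
```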
This commit is contained in:
parent 3143324984, commit bfd719f9d8
`langchain/chat_models/openai.py`:

```diff
@@ -13,6 +13,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
@@ -29,8 +30,9 @@ from langchain.chat_models.base import (
     _generate_from_stream,
 )
 from langchain.llms.base import create_base_retry_decorator
-from langchain.pydantic_v1 import Field, root_validator
+from langchain.pydantic_v1 import BaseModel, Field, root_validator
 from langchain.schema import ChatGeneration, ChatResult
+from langchain.schema.language_model import LanguageModelInput
 from langchain.schema.messages import (
     AIMessageChunk,
     BaseMessage,
@@ -41,11 +43,13 @@ from langchain.schema.messages import (
     SystemMessageChunk,
 )
 from langchain.schema.output import ChatGenerationChunk
+from langchain.schema.runnable import Runnable
 from langchain.utils import get_from_dict_or_env, get_pydantic_field_names

 if TYPE_CHECKING:
     import tiktoken


 logger = logging.getLogger(__name__)


@@ -540,3 +544,45 @@ class ChatOpenAI(BaseChatModel):
         # every reply is primed with <im_start>assistant
         num_tokens += 3
         return num_tokens
+
+    def bind_functions(
+        self,
+        functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+        function_call: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind functions (and other objects) to this chat model.
+
+        Args:
+            functions: A list of function definitions to bind to this chat model.
+                Can be a dictionary, pydantic model, or callable. Pydantic
+                models and callables will be automatically converted to
+                their schema dictionary representation.
+            function_call: Which function to require the model to call.
+                Must be the name of the single provided function or
+                "auto" to automatically determine which function to call
+                (if any).
+            kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        from langchain.chains.openai_functions.base import convert_to_openai_function
+
+        formatted_functions = [convert_to_openai_function(fn) for fn in functions]
+        function_call_ = None
+        if function_call is not None:
+            if len(formatted_functions) != 1:
+                raise ValueError(
+                    "When specifying `function_call`, you must provide exactly one "
+                    "function."
+                )
+            if formatted_functions[0]["name"] != function_call:
+                raise ValueError(
+                    f"Function call {function_call} was specified, but the only "
+                    f"provided function was {formatted_functions[0]['name']}."
+                )
+            function_call_ = {"name": function_call}
+            kwargs = {**kwargs, "function_call": function_call_}
+        return super().bind(
+            functions=formatted_functions,
+            **kwargs,
+        )
```
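To make the docstring's "schema dictionary representation" concrete, here is a rough sketch of what `convert_to_openai_function` yields for the `IssueClassification` model from the PR description. The output shape is approximated from how the converter builds a function payload out of the pydantic schema; exact keys depend on pydantic's schema generation:

```python
from langchain.chains.openai_functions.base import convert_to_openai_function

formatted = convert_to_openai_function(IssueClassification)
# Approximately:
# {
#     "name": "IssueClassification",
#     "description": "Classify an issue.",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "category": {...},           # enum: bug / nit / improvement / other
#             "other_description": {...},  # optional string
#         },
#         "required": ["category"],
#     },
# }
```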
`tests/integration_tests/chat_models/test_openai.py`:

```diff
@@ -9,7 +9,9 @@ from langchain.chains.openai_functions import (
     create_openai_fn_chain,
 )
 from langchain.chat_models.openai import ChatOpenAI
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.pydantic_v1 import BaseModel, Field
 from langchain.schema import (
     ChatGeneration,
     ChatResult,
@@ -297,6 +299,46 @@ async def test_async_chat_openai_streaming_with_function() -> None:
     assert all([chunk is not None for chunk in callback_handler._captured_chunks])


+@pytest.mark.scheduled
+@pytest.mark.asyncio
+async def test_async_chat_openai_bind_functions() -> None:
+    """Test ChatOpenAI wrapper with multiple completions."""
+
+    class Person(BaseModel):
+        """Identifying information about a person."""
+
+        name: str = Field(..., title="Name", description="The person's name")
+        age: int = Field(..., title="Age", description="The person's age")
+        fav_food: Optional[str] = Field(
+            default=None, title="Fav Food", description="The person's favorite food"
+        )
+
+    chat = ChatOpenAI(
+        max_tokens=30,
+        n=1,
+        streaming=True,
+    ).bind_functions(functions=[Person], function_call="Person")
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", "Use the provided Person function"),
+            ("user", "{input}"),
+        ]
+    )
+
+    chain = prompt | chat | JsonOutputFunctionsParser(args_only=True)
+
+    message = HumanMessage(content="Sally is 13 years old")
+    response = await chain.abatch([{"input": message}])
+
+    assert isinstance(response, list)
+    assert len(response) == 1
+    for generation in response:
+        assert isinstance(generation, dict)
+        assert "name" in generation
+        assert "age" in generation
+
+
 def test_chat_openai_extra_kwargs() -> None:
     """Test extra kwargs to chat openai."""
     # Check that foo is saved in extra_kwargs.
```
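One behavioral detail worth spelling out from the `bind_functions` validation above: `function_call` is only accepted when exactly one function is bound, and its value must match that function's name exactly (so although the docstring mentions `"auto"`, the check as written only passes the bound function's own name). A minimal sketch of the two failure modes, using plain-dict function definitions since `convert_to_openai_function` passes dictionaries through unchanged; the function names here are illustrative only, and constructing `ChatOpenAI` assumes `OPENAI_API_KEY` is set (binding itself makes no network call):

```python
import pytest

from langchain.chat_models import ChatOpenAI

# Hand-written function schemas; dicts are passed through as-is.
search = {
    "name": "search",
    "description": "Search the web.",
    "parameters": {"type": "object", "properties": {}},
}
lookup = {
    "name": "lookup",
    "description": "Look up a term.",
    "parameters": {"type": "object", "properties": {}},
}

llm = ChatOpenAI()

# function_call with more than one bound function is rejected.
with pytest.raises(ValueError):
    llm.bind_functions([search, lookup], function_call="search")

# function_call must name the single function that was provided.
with pytest.raises(ValueError):
    llm.bind_functions([search], function_call="lookup")
```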