mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
bind_functions convenience method (#12518)
I always take 20-30 seconds to re-discover where the `convert_to_openai_function` wrapper lives in our codebase. Chat langchain [has no clue](https://smith.langchain.com/public/3989d687-18c7-4108-958e-96e88803da86/r) what to do either. There's the older `create_openai_fn_chain` , but we haven't been recommending it in LCEL. The example we show in the [cookbook](https://python.langchain.com/docs/expression_language/how_to/binding#attaching-openai-functions) is really verbose. General function calling should be as simple as possible to do, so this seems a bit more ergonomic to me (feel free to disagree). Another option would be to directly coerce directly in the class's init (or when calling invoke), if provided. I'm not 100% set against that. That approach may be too easy but not simple. This PR feels like a decent compromise between simple and easy. ``` from enum import Enum from typing import Optional from pydantic import BaseModel, Field class Category(str, Enum): """The category of the issue.""" bug = "bug" nit = "nit" improvement = "improvement" other = "other" class IssueClassification(BaseModel): """Classify an issue.""" category: Category other_description: Optional[str] = Field( description="If classified as 'other', the suggested other category" ) from langchain.chat_models import ChatOpenAI llm = ChatOpenAI().bind_functions([IssueClassification]) llm.invoke("This PR adds a convenience wrapper to the bind argument") # AIMessage(content='', additional_kwargs={'function_call': {'name': 'IssueClassification', 'arguments': '{\n "category": "improvement"\n}'}}) ```
This commit is contained in:
parent
3143324984
commit
bfd719f9d8
@ -13,6 +13,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
@ -29,8 +30,9 @@ from langchain.chat_models.base import (
|
|||||||
_generate_from_stream,
|
_generate_from_stream,
|
||||||
)
|
)
|
||||||
from langchain.llms.base import create_base_retry_decorator
|
from langchain.llms.base import create_base_retry_decorator
|
||||||
from langchain.pydantic_v1 import Field, root_validator
|
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
from langchain.schema import ChatGeneration, ChatResult
|
from langchain.schema import ChatGeneration, ChatResult
|
||||||
|
from langchain.schema.language_model import LanguageModelInput
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -41,11 +43,13 @@ from langchain.schema.messages import (
|
|||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
)
|
)
|
||||||
from langchain.schema.output import ChatGenerationChunk
|
from langchain.schema.output import ChatGenerationChunk
|
||||||
|
from langchain.schema.runnable import Runnable
|
||||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -540,3 +544,45 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
# every reply is primed with <im_start>assistant
|
# every reply is primed with <im_start>assistant
|
||||||
num_tokens += 3
|
num_tokens += 3
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
def bind_functions(
|
||||||
|
self,
|
||||||
|
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
|
||||||
|
function_call: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
|
"""Bind functions (and other objects) to this chat model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
functions: A list of function definitions to bind to this chat model.
|
||||||
|
Can be a dictionary, pydantic model, or callable. Pydantic
|
||||||
|
models and callables will be automatically converted to
|
||||||
|
their schema dictionary representation.
|
||||||
|
function_call: Which function to require the model to call.
|
||||||
|
Must be the name of the single provided function or
|
||||||
|
"auto" to automatically determine which function to call
|
||||||
|
(if any).
|
||||||
|
kwargs: Any additional parameters to pass to the
|
||||||
|
:class:`~langchain.runnable.Runnable` constructor.
|
||||||
|
"""
|
||||||
|
from langchain.chains.openai_functions.base import convert_to_openai_function
|
||||||
|
|
||||||
|
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
||||||
|
function_call_ = None
|
||||||
|
if function_call is not None:
|
||||||
|
if len(formatted_functions) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"When specifying `function_call`, you must provide exactly one "
|
||||||
|
"function."
|
||||||
|
)
|
||||||
|
if formatted_functions[0]["name"] != function_call:
|
||||||
|
raise ValueError(
|
||||||
|
f"Function call {function_call} was specified, but the only "
|
||||||
|
f"provided function was {formatted_functions[0]['name']}."
|
||||||
|
)
|
||||||
|
function_call_ = {"name": function_call}
|
||||||
|
kwargs = {**kwargs, "function_call": function_call_}
|
||||||
|
return super().bind(
|
||||||
|
functions=formatted_functions,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
@ -9,7 +9,9 @@ from langchain.chains.openai_functions import (
|
|||||||
create_openai_fn_chain,
|
create_openai_fn_chain,
|
||||||
)
|
)
|
||||||
from langchain.chat_models.openai import ChatOpenAI
|
from langchain.chat_models.openai import ChatOpenAI
|
||||||
|
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
|
from langchain.pydantic_v1 import BaseModel, Field
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
ChatResult,
|
ChatResult,
|
||||||
@ -297,6 +299,46 @@ async def test_async_chat_openai_streaming_with_function() -> None:
|
|||||||
assert all([chunk is not None for chunk in callback_handler._captured_chunks])
|
assert all([chunk is not None for chunk in callback_handler._captured_chunks])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.scheduled
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_chat_openai_bind_functions() -> None:
|
||||||
|
"""Test ChatOpenAI wrapper with multiple completions."""
|
||||||
|
|
||||||
|
class Person(BaseModel):
|
||||||
|
"""Identifying information about a person."""
|
||||||
|
|
||||||
|
name: str = Field(..., title="Name", description="The person's name")
|
||||||
|
age: int = Field(..., title="Age", description="The person's age")
|
||||||
|
fav_food: Optional[str] = Field(
|
||||||
|
default=None, title="Fav Food", description="The person's favorite food"
|
||||||
|
)
|
||||||
|
|
||||||
|
chat = ChatOpenAI(
|
||||||
|
max_tokens=30,
|
||||||
|
n=1,
|
||||||
|
streaming=True,
|
||||||
|
).bind_functions(functions=[Person], function_call="Person")
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
("system", "Use the provided Person function"),
|
||||||
|
("user", "{input}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = prompt | chat | JsonOutputFunctionsParser(args_only=True)
|
||||||
|
|
||||||
|
message = HumanMessage(content="Sally is 13 years old")
|
||||||
|
response = await chain.abatch([{"input": message}])
|
||||||
|
|
||||||
|
assert isinstance(response, list)
|
||||||
|
assert len(response) == 1
|
||||||
|
for generation in response:
|
||||||
|
assert isinstance(generation, dict)
|
||||||
|
assert "name" in generation
|
||||||
|
assert "age" in generation
|
||||||
|
|
||||||
|
|
||||||
def test_chat_openai_extra_kwargs() -> None:
|
def test_chat_openai_extra_kwargs() -> None:
|
||||||
"""Test extra kwargs to chat openai."""
|
"""Test extra kwargs to chat openai."""
|
||||||
# Check that foo is saved in extra_kwargs.
|
# Check that foo is saved in extra_kwargs.
|
||||||
|
Loading…
Reference in New Issue
Block a user